diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index 8a9acef15aa9d..77c0e0196d15a 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -16,6 +16,7 @@ #include #include "paddle/phi/core/infermeta_utils.h" #include "paddle/pir/core/builder.h" +#include "paddle/pir/core/dll_decl.h" #include "paddle/pir/core/ir_printer.h" #include "paddle/pir/core/op_base.h" #include "paddle/pir/core/operation.h" @@ -24,7 +25,7 @@ namespace cinn { namespace dialect { -class GroupOp : public pir::Op { +class IR_API GroupOp : public pir::Op { public: using Op::Op; static const char *name() { return "cinn_op.group"; } @@ -82,7 +83,7 @@ class IR_API SplitOp : public pir::Op { void VerifySig() const {} }; -class GenerateShapeOp : public pir::Op { +class IR_API GenerateShapeOp : public pir::Op { public: using Op::Op; static const char *name() { return "cinn_op.generate_shape"; } @@ -121,7 +122,7 @@ class GenerateShapeOp : public pir::Op { } // namespace dialect } // namespace cinn -IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp) -IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp) -IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp) -IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 1b9ca43b7d9f1..79c8e703e1184 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -116,7 +116,9 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op, shapes.push_back(dim_expr); } - symbol::ShapeOrDataDimExprs shape_data{shapes}; + symbol::ShapeOrDataDimExprs shape_data{ + shapes, + shape_analysis->value_id_to_shapeordata_[operand_source_id].shape()}; shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } @@ -146,9 +148,9 @@ bool ReshapeOpInferSymbolicShape( pir::OpResult res = op->result(0); std::string res_id = pir::GetValueId(&res); - symbol::ShapeOrDataDimExprs shape_data; + symbol::ShapeOrDataDimExprs shape_data{ + *(shape_analysis->value_id_to_shapeordata_[operand_source_1_id].data())}; - shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id]; shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } @@ -158,6 +160,54 @@ bool Reshape_OpInferSymbolicShape( return ReshapeOpInferSymbolicShape(op, shape_analysis); } +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + std::string operand_source_id = pir::GetValueId(&operand_source); + pir::OpResult res = op->result(0); + std::string res_id = pir::GetValueId(&res); + + std::vector dims = + common::vectorize(res.type().dyn_cast().dims()); + + std::vector shapes; + for (int64_t dim : dims) { + symbol::DimExpr dim_expr; + if (dim == -1) { + symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); + dim_expr = res_dim_expr; + } else { + symbol::DimExpr res_dim_expr(dim); + dim_expr = res_dim_expr; + } + shapes.push_back(dim_expr); + } + + auto operand_source_1 = 
op->operand_source(1); + std::string operand_source_1_id = pir::GetValueId(&operand_source_1); + auto starts_array = + (shape_analysis->value_id_to_shapeordata_[operand_source_1_id]).data(); + auto start = starts_array->at(0).Get(); + + auto operand_source_2 = op->operand_source(2); + std::string operand_source_2_id = pir::GetValueId(&operand_source_2); + auto ends_array = + (shape_analysis->value_id_to_shapeordata_[operand_source_2_id]).data(); + auto end = ends_array->at(0).Get(); + + std::vector data; + auto source_data = + (shape_analysis->value_id_to_shapeordata_[operand_source_id]).data(); + + for (int i = start; i < end; i++) { + data.emplace_back(source_data->at(i)); + } + + symbol::ShapeOrDataDimExprs shape_data{shapes, data}; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + return true; +} + } // namespace paddle::dialect namespace cinn::dialect { @@ -184,17 +234,25 @@ bool SliceOpInferSymbolicShape(pir::Operation *op, shapes.push_back(dim_expr); } - // pir::AttributeMap attributes = op->attributes(); + pir::AttributeMap attributes = op->attributes(); + + auto attr_starts = + attributes["starts"].dyn_cast().AsVector(); + auto start = attr_starts[0].dyn_cast().data(); - // auto attr_starts = - // attributes["starts"].dyn_cast().AsVector(); - // auto start = attr_starts[0].dyn_cast().data(); + auto attr_ends = + attributes["ends"].dyn_cast().AsVector(); + auto end = attr_ends[0].dyn_cast().data(); - // auto attr_ends = - // attributes["ends"].dyn_cast().AsVector(); - // auto end = attr_ends[0].dyn_cast().data(); + std::vector data; + auto source_data = + (shape_analysis->value_id_to_shapeordata_[operand_source_id]).data(); + + for (int i = start; i < end; i++) { + data.emplace_back(source_data->at(i)); + } - symbol::ShapeOrDataDimExprs shape_data{shapes}; + symbol::ShapeOrDataDimExprs shape_data{shapes, data}; shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h index b1c72e3111df2..fc96df40596af 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h @@ -101,6 +101,9 @@ bool ReshapeOpInferSymbolicShape( bool Reshape_OpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + } // namespace paddle::dialect namespace cinn::dialect { diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index a9129a28793b0..969edf32204bf 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -44,21 +44,25 @@ struct CombineOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { - symbol::ShapeOrDataDimExprs value_shape; - - // for (auto operand_source : op->operands_source()) { - // std::string operand_source_id = pir::GetValueId(&operand_source); - // auto source_shape_vec = - // shape_analysis->value_id_to_shapeordata_[operand_source_id]; - // for (int i = 0; i < source_shape_vec.size(); i++) { - // value_shape.second.emplace_back(source_shape_vec[i]); - // } - // } + std::vector shapes; + std::vector data; + + for 
(auto operand_source : op->operands_source()) { + std::string operand_source_id = pir::GetValueId(&operand_source); + auto source_data_p = + shape_analysis->value_id_to_shapeordata_[operand_source_id].data(); + auto source_shape_vec = + source_data_p.value_or(std::vector{}); + for (size_t i = 0; i < source_shape_vec.size(); i++) { + data.emplace_back(source_shape_vec.at(i)); + } + } auto res = op->result(0); auto res_id = pir::GetValueId(&res); - shape_analysis->value_id_to_shapeordata_[res_id] = value_shape; + symbol::ShapeOrDataDimExprs shape_data{shapes, data}; + shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; return true; } diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 221aeb6c7dfa3..97fa1a6879e0a 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1179,6 +1179,7 @@ kernel : func : slice backward : slice_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : soft_relu args : (Tensor x, float threshold = 20.0f) diff --git a/paddle/fluid/pir/transforms/CMakeLists.txt b/paddle/fluid/pir/transforms/CMakeLists.txt index 83f9680e1cd5e..a5ffb11f0063c 100644 --- a/paddle/fluid/pir/transforms/CMakeLists.txt +++ b/paddle/fluid/pir/transforms/CMakeLists.txt @@ -1,12 +1,9 @@ file(GLOB_RECURSE transforms_srcs "*.cc") if(NOT WITH_CINN) list( - REMOVE_ITEM - transforms_srcs - ${CMAKE_CURRENT_SOURCE_DIR}/build_cinn_pass.cc + REMOVE_ITEM transforms_srcs ${CMAKE_CURRENT_SOURCE_DIR}/build_cinn_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_extract_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_detector.cc - ${CMAKE_CURRENT_SOURCE_DIR}/shape_optimization_pass.cc) + ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_detector.cc) endif() set(transforms_deps drr op_dialect op_dialect_vjp standalone_executor pir diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index 5c6481110034e..1ad2700684186 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/shape_optimization_pass.h" -#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -27,117 +26,6 @@ #include "paddle/pir/pattern_rewrite/pattern_match.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" -namespace { - -void InferUnaryElementwiseSymbolicShape( - const pir::Operation& op, - const std::shared_ptr& shape_analysis) { - auto input = op.operand_source(0); - auto output = op.result(0); - const auto& in_sym_dims = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(input); - const auto& out_sym_dims = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(output); - pir::SymbolicDimMgr& sym_dim_mgr = shape_analysis->symbolicDimMgr(); - for (auto i = 0; i < out_sym_dims.size(); ++i) { - if (in_sym_dims[i].IsDynamic() || out_sym_dims[i].IsDynamic()) { - sym_dim_mgr.MapSymbolicDimEqual(in_sym_dims[i], out_sym_dims[i]); - } else { - // do nothing - } - } -} - -// TODO(zyfncg): support broadcast for elementwise ops. 
-void InferBinaryElementwiseSymbolicShape( - const pir::Operation& op, - const std::shared_ptr& shape_analysis) { - auto input0 = op.operand_source(0); - auto input1 = op.operand_source(1); - auto output = op.result(0); - const auto& in_sym_dims0 = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(input0); - const auto& in_sym_dims1 = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(input1); - const auto& out_sym_dims = - shape_analysis->GetOrCreateSymbolicDimsForRankedValue(output); - pir::SymbolicDimMgr& sym_dim_mgr = shape_analysis->symbolicDimMgr(); - for (auto i = 0; i < out_sym_dims.size(); ++i) { - if (in_sym_dims0[i].IsDynamic() || in_sym_dims1[i].IsDynamic() || - out_sym_dims[i].IsDynamic()) { - sym_dim_mgr.MapSymbolicDimEqual(in_sym_dims0[i], out_sym_dims[i]); - sym_dim_mgr.MapSymbolicDimEqual(in_sym_dims1[i], out_sym_dims[i]); - } else { - // do nothing - } - } -} - -class InferSymbolicShapePass : public pir::Pass { - public: - InferSymbolicShapePass( - const std::shared_ptr& shape_analysis) - : pir::Pass("infer_symbolic_shape_pass", /*opt_level=*/1), - shape_analysis_(shape_analysis) {} - - void Run(pir::Operation* op) override { - auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "infer_symbolic_shape_pass should run on module op."); - - for (auto& op : module_op.block()) { - if (op.isa()) { - for (auto* local_op : op.dyn_cast().ops()) { - InferSymbolicShape(*local_op); - } - } else { - InferSymbolicShape(op); - } - } - } - - bool CanApplyOn(pir::Operation* op) const override { - return op->isa() && op->num_regions() > 0; - } - - private: - typedef void (*InferSymShapeFunc)( - const pir::Operation&, - const std::shared_ptr&); - void InferSymbolicShape(const pir::Operation& op) { - thread_local static std::unordered_map - infer_sym_shape_map(GetInferSymShapeMap()); - auto it = infer_sym_shape_map.find(op.name()); - - if (it != infer_sym_shape_map.end()) { - it->second(op, shape_analysis_); - } else { - LOG(WARNING) << "[" << op.name() - << "] is not supported for infer_symbolic_shape pass."; - } - } - - static std::unordered_map - GetInferSymShapeMap() { - return std::unordered_map{ - {paddle::dialect::ExpOp::name(), &InferUnaryElementwiseSymbolicShape}, - {paddle::dialect::SubtractOp::name(), - &InferBinaryElementwiseSymbolicShape}}; - } - - std::shared_ptr shape_analysis_; -}; - -} // namespace - -namespace pir { - -std::unique_ptr CreateInferSymbolicShapePass( - const std::shared_ptr& shape_analysis) { - return std::make_unique(shape_analysis); -} - -} // namespace pir - namespace pir { namespace { @@ -446,14 +334,12 @@ bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { // TODO(zhangbopd): Do some Canonicalizer. 
pir::SymbolicDimMgr mgr(m); - IR_ENFORCE(mgr.Load(), - "SymbolicDimMgr Load failed in OptimizeShapeComputation."); + ShapeComputationIRAnalysis analysis(m, mgr); if (!analysis.Run()) { return false; } - IR_ENFORCE(mgr.Save(), - "SymbolicDimMgr save failed in OptimizeShapeComputation."); + return true; } @@ -469,7 +355,7 @@ void PrintProgram(pir::ModuleOp m, std::string mgs) { void DebugPrintOpInfo( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { - VLOG(0) << op->name() << ", num_operands: " << op->num_operands(); + VLOG(3) << op->name() << ", num_operands: " << op->num_operands(); for (auto& res : op->results()) { auto value_id = pir::GetValueId(&res); std::ostringstream print_stream; @@ -503,7 +389,7 @@ void DebugPrintOpInfo( } print_stream << "]\n"; } - VLOG(0) << print_stream.str(); + VLOG(3) << print_stream.str(); } } @@ -511,7 +397,7 @@ void InferSymExprForAllValues(ModuleOp module_op) { auto shape_analysis_mgr = ShapeAnalysisManager::Instance(); ShapeConstraintIRAnalysis& shape_analysis = shape_analysis_mgr.Get(module_op.program()); - for (int i = 0; i < module_op->num_regions(); i++) { + for (uint32_t i = 0; i < module_op->num_regions(); i++) { for (auto& block : module_op->region(i)) { for (auto& op : block) { if (op.num_operands() == 0) { @@ -534,18 +420,21 @@ void InferSymExprForAllValues(ModuleOp module_op) { shapes.push_back(dim_expr); } + symbol::ShapeOrDataDimExprs shape_data{shapes}; + shape_analysis.value_id_to_shapeordata_[value_id] = shape_data; + if (op.name() == "pd_op.full_int_array") { + std::vector data; auto attributes = op.attributes(); auto attr = attributes["value"]; auto arr = attr.dyn_cast(); const auto& vec = arr.AsVector(); for (auto item : vec) { int64_t i = item.dyn_cast().data(); - shapes.push_back(symbol::DimExpr(i)); + data.push_back(symbol::DimExpr(i)); } + shape_analysis.value_id_to_shapeordata_[value_id].SetData(data); } - symbol::ShapeOrDataDimExprs shape_data{shapes}; - shape_analysis.value_id_to_shapeordata_[value_id] = shape_data; } } else { auto infer_symbolic_shape_interface = @@ -574,14 +463,11 @@ class ShapeOptimizationPass : public pir::Pass { PrintProgram(module_op, "Origin Program"); InferSymExprForAllValues(module_op); - MaterializeShapeComputation(module_op); // Runner is for Canonicalizer. PassPipelineRunner runner = [this](pir::PassManager& pm, pir::ModuleOp m) { return pm.Run(m.program()); }; - // if (!OptimizeShapeComputation(module_op, runner)) { - // return; - // } + VLOG(3) << "===================== ShapeOptimizationPass Run End. " "============================="; } diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.h b/paddle/fluid/pir/transforms/shape_optimization_pass.h index fa192972a41b8..cbaa377157823 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.h +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.h @@ -22,10 +22,6 @@ namespace pir { class Pass; -// Apply some shape-related optimization. 
-IR_API std::unique_ptr CreateInferSymbolicShapePass( - const std::shared_ptr& shape_analysis); - IR_API std::unique_ptr CreateShapeOptimizationPass(); } // namespace pir diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index a477f42e40c48..bb0d1d230052a 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -130,6 +130,7 @@ USE_PIR_PASS(conv2d_add_act_fuse_pass); USE_PIR_PASS(fused_dot_product_attention_pass); PHI_DECLARE_bool(print_ir); +PHI_DECLARE_bool(pir_apply_shape_optimization_pass); namespace paddle { namespace pybind { @@ -1629,7 +1630,6 @@ void AddCinnPass(std::shared_ptr &pass_manager, // NOLINT has_dynamic_shape ? std::make_shared(ctx) : nullptr; - pass_manager->AddPass(pir::CreateShapeOptimizationPass()); cinn::dialect::ir::PdOp2CinnOpConverter(&program); pass_manager->AddPass( @@ -1637,10 +1637,6 @@ void AddCinnPass(std::shared_ptr &pass_manager, // NOLINT pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); pass_manager->AddPass(pir::CreateBuildCinnPass()); - if (has_dynamic_shape) { - pass_manager->AddPass(pir::CreateInferSymbolicShapePass(shape_analysis)); - } - pass_manager->AddPass( cinn::dialect::ir::CreateCinnGroupLoweringPass(shape_analysis)); VLOG(4) << "has_dynamic_shape :" << has_dynamic_shape @@ -1651,8 +1647,18 @@ void AddCinnPass(std::shared_ptr &pass_manager, // NOLINT "compile PaddlePaddle with CINN")); #endif } + +void InferSymbolicShapePass( + std::shared_ptr &pass_manager, // NOLINT + Program &program) { // NOLINT + if (FLAGS_pir_apply_shape_optimization_pass) { + pass_manager->AddPass(pir::CreateShapeOptimizationPass()); + } +} + void BindIrPass(pybind11::module *m) { m->def("add_cinn_pass", AddCinnPass); + m->def("infer_symbolic_shape_pass", InferSymbolicShapePass); py::class_> pass(*m, "Pass", diff --git a/paddle/phi/core/flags.cc b/paddle/phi/core/flags.cc index ea1af5eee4d0b..77b03f7efda2e 100644 --- a/paddle/phi/core/flags.cc +++ b/paddle/phi/core/flags.cc @@ -1413,6 +1413,19 @@ PHI_DEFINE_EXPORTED_bool(pir_apply_inplace_pass, "Whether to apply inplace pass on lowering " "::pir::Program to Kernel Dialect"); +/** + * Apply shape optimization pass to new IR FLAG + * Name: pir_apply_shape_optimization_pass + * Since Version: 3.0.0 + * Value Range: bool, default=false + * Example: + * Note: If Ture, will apply shape_optimization pass to new IR. 
+ */ +PHI_DEFINE_EXPORTED_bool(pir_apply_shape_optimization_pass, + false, + "Whether to apply shape_optimization pass " + "to infer symbolic shape"); + PHI_DEFINE_EXPORTED_string( ir_inplace_kernel_blacklist, "", diff --git a/paddle/pir/dialect/shape/utils/dim_expr.h b/paddle/pir/dialect/shape/utils/dim_expr.h index a65390200cd06..4363d50769170 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/dialect/shape/utils/dim_expr.h @@ -223,6 +223,8 @@ class ShapeOrData { public: explicit ShapeOrData(const std::vector& shape) : shape_(shape), data_(std::nullopt) {} + explicit ShapeOrData(const std::vector& shape, const std::vector& data) + : shape_(shape), data_(data) {} ShapeOrData() = default; ShapeOrData(const ShapeOrData&) = default; ShapeOrData(ShapeOrData&&) = default; @@ -238,11 +240,9 @@ class ShapeOrData { const std::vector& shape() const { return shape_; } // Specfic for Tensor generated by shape-relevant ops const std::optional>& data() const { return data_; } + void SetData(const std::vector& data) { data_ = data; } private: - explicit ShapeOrData(const std::vector& shape, const std::vector& data) - : shape_(shape), data_(data) {} - std::vector shape_; std::optional> data_; }; diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc index 0d8305c5c934a..a2e8f4c6ee10a 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.cc @@ -60,55 +60,6 @@ SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) { symbol_table_ = SymbolTable(func); } -bool SymbolicDimMgr::Load() { - auto func_op = symbol_table_.getOp()->dyn_cast(); - IR_ENFORCE(func_op); - for (auto& op : *(func_op.block())) { - symbol_table_.insert(&op); - if (SymbolicDimOp sym_dim_op = op.dyn_cast()) { - symbol_dim_union_set_[sym_dim_op] = sym_dim_op; - symbol_name_set_.insert(sym_dim_op.GetSymName()); - } - } - return LoadShapeConstraintGraph(); -} - -bool SymbolicDimMgr::LoadShapeConstraintGraph() { - // TODO(zhangbopd): add more constraint function. currently, only support - // tie_product_equal. 
- auto constraint_vec = - symbol_table_.Lookup("tie_product_equal"); - - if (!constraint_vec.size()) return true; - - auto build_sym_product = [&](std::vector range, - SymbolicDimProduct& product) { - for (Value v : range) { - auto defining_op = v.dyn_cast().owner(); - if (auto constOp = defining_op->dyn_cast()) { - product.factor *= constOp.value().dyn_cast().data(); - continue; - } else if (auto dim_op = defining_op->dyn_cast()) { - auto sym = symbol_table_.Lookup(dim_op.GetName()); - if (!sym) return false; - product.symbols.push_back(sym); - continue; - } - return false; - } - return true; - }; - - for (auto op : constraint_vec) { - SymbolicDimProduct lhs, rhs; - if (!build_sym_product(op.lhs(), lhs) || - !build_sym_product(op.rhs(), rhs) || - !MapSymbolicDimProductEqual(lhs, rhs)) - return false; - } - return true; -} - bool SymbolicDimMgr::MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { SymbolicDimProduct new_lhs, new_rhs; @@ -457,167 +408,4 @@ bool SymbolicDimMgr::IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, return IsMultipleOfKnownSymbolicDimProductEqualPair(new_lhs, new_rhs); } -bool SymbolicDimMgr::Save() { - using Name2SymbolFn = std::function; - auto update_attrs = [&](ArrayAttribute attrs, Name2SymbolFn fn) { - std::vector new_attrs; - for (Attribute attr : attrs.AsVector()) { - auto sym = fn(attr.dyn_cast().AsString()); - IR_ENFORCE(sym); - SymbolicDimOp root = GetRootSymbolicDim(sym); - Attribute root_symbol = - StrAttribute::get(m_->ir_context(), root.GetSymName()); - new_attrs.push_back(root_symbol); - } - return ArrayAttribute::get(m_->ir_context(), new_attrs); - }; - - // TODO(zhangbopd): update attributes attached in DenseTensorType - for (auto& op : m_.block()) { - if (!op.HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; - auto attrs = - op.attribute(SymbolicDimOp::GetSymbolicDimAttrName()); - auto symbolic_shape_attr = - update_attrs(attrs, [&](const std::string& name) { - return symbol_table_.Lookup(name); - }); - op.set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), - symbolic_shape_attr); - } - if (!UpdateProductEqualityMap()) { - return false; - } - std::unordered_set used_symbolic_ops; - std::vector used_symbol_names; - // TODO(zhangbopd): collect uses in value. 
- auto collect_used_symbols = [&](ArrayAttribute attrs) { - for (Attribute attr : attrs.AsVector()) { - auto sym = symbol_table_.Lookup( - attr.dyn_cast().AsString()); - IR_ENFORCE(sym); - if (used_symbolic_ops.insert(sym).second) - used_symbol_names.push_back(sym.GetSymName()); - } - }; - for (auto& op : m_.block()) { - if (!op.HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; - auto attrs = - op.attribute(SymbolicDimOp::GetSymbolicDimAttrName()); - collect_used_symbols(attrs); - } - auto func_op = symbol_table_.getOp()->dyn_cast(); - IR_ENFORCE(func_op); - for (auto& p : symbol_dim_union_set_) { - if (!used_symbolic_ops.count(p.first)) { - func_op.block()->erase(*(p.first.operation())); - } - } - - std::vector candidates; - for (auto& outter : product_equality_map_) { - if (std::any_of(outter.first.symbols.begin(), - outter.first.symbols.end(), - [&](SymbolicDimOp sym) { - return used_symbolic_ops.count(sym) == 0; - })) - candidates.push_back(outter.first); - } - - for (auto& prod : candidates) product_equality_map_.erase(prod); - for (auto& outter : product_equality_map_) { - std::vector candidates; - for (auto& inner : outter.second) { - if (std::any_of(inner.first.symbols.begin(), - inner.first.symbols.end(), - [&](SymbolicDimOp sym) { - return used_symbolic_ops.count(sym) == 0; - })) - candidates.push_back(outter.first); - } - for (auto& prod : candidates) outter.second.erase(prod); - } - - std::sort(used_symbol_names.begin(), - used_symbol_names.end(), - [&](const std::string& lhs, const std::string& rhs) { - return CompareSymbolicDimNames(lhs, rhs); - }); - int non_const_dims_num = 0; - std::unordered_map name_mapping; - for (const auto& name : used_symbol_names) { - if (name.size() > 0 && name[0] == 'C') { - name_mapping[name] = name; - } else { - name_mapping[name] = ("S" + std::to_string(non_const_dims_num++)); - } - } - - std::unordered_map name_to_symbol; - for (SymbolicDimOp op : used_symbolic_ops) { - auto name = op.GetSymName(); - op.SetSymName(name_mapping[name]); - name_to_symbol[name] = op; - } - - for (auto& op : m_.block()) { - if (!op.HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) continue; - auto attrs = - op.attribute(SymbolicDimOp::GetSymbolicDimAttrName()); - auto symbolic_shape_attr = update_attrs( - attrs, [&](const std::string& name) { return name_to_symbol[name]; }); - op.set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), - symbolic_shape_attr); - } - - // TODO(zhangbopd): update attributes attached to values. 
- - return SaveShapeConstraintGraph(); -} - -bool SymbolicDimMgr::SaveShapeConstraintGraph() { - auto func_op = symbol_table_.getOp()->dyn_cast(); - IR_ENFORCE(func_op); - auto op_it = func_op.block()->rbegin(); - while (op_it != func_op.block()->rend()) { - if ((op_it->isa()) || - (op_it->isa())) - op_it++; - else - op_it = decltype(op_it)(func_op.block()->erase(*op_it)); - } - - // save product equal predicate - Builder builder = Builder(m_->ir_context(), func_op.block()); - auto build_operands = [&](const SymbolicDimProduct& prod) { - std::vector values; - - if (prod.factor != 1) { - values.push_back( - builder - .Build( - Int32Attribute::get(m_->ir_context(), prod.factor), - Int32Type::get(m_->ir_context())) - ->result(0)); - } - for (SymbolicDimOp sym : prod.symbols) { - values.push_back(builder.Build(sym.GetSymName()).out()); - } - return values; - }; - std::vector sorted_product_vec; - for (auto& p : product_equality_map_) sorted_product_vec.push_back(p.first); - std::sort(sorted_product_vec.begin(), - sorted_product_vec.end(), - CompareSymbolicDimProduct); - for (auto& x : sorted_product_vec) { - for (auto& y : sorted_product_vec) { - if (!CompareSymbolicDimProduct(x, y)) continue; - if (!product_equality_map_[x][y]) continue; - auto lhs_operands = build_operands(x); - auto rhs_operands = build_operands(y); - builder.Build(lhs_operands, rhs_operands); - } - } - return true; -} } // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h index a2a67c27ff713..7797ab4f2ffb2 100644 --- a/paddle/pir/dialect/shape/utils/shape_optimization_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_optimization_utils.h @@ -65,9 +65,6 @@ class IR_API SymbolicDimMgr { public: explicit SymbolicDimMgr(ModuleOp m); - // Loads pre-defined SymbolicDimOp ops from the module this mgr runs on. - bool Load(); - // Create a new symbolicDim instance owned by this mgr. SymbolicDimOp NewSymbolicDim(const std::string& name = {}); @@ -117,16 +114,11 @@ class IR_API SymbolicDimMgr { bool MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); - // Saves the updated shape constraint IR - bool Save(); - // retuns the SymbolTable. 
SymbolTable& symbolTable() { return symbol_table_; } private: const std::string GetNextName(); - bool SaveShapeConstraintGraph(); - bool LoadShapeConstraintGraph(); bool UpdateProductEqualityMap(); bool IsMultipleOfKnownSymbolicDimProductEqualPair( const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index 4beb53dde4911..574805c61f020 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -47,7 +47,6 @@ bool ShapeAnalysis::IsProductEqual( ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(ModuleOp m) : m_(m), mgr_(m) { - mgr_.Load(); for (auto& op : m.block()) { auto tie_shape_op = op.dyn_cast(); if (!tie_shape_op) continue; @@ -66,9 +65,7 @@ ShapeConstraintIRAnalysis::ShapeConstraintIRAnalysis(ModuleOp m) } } -ShapeConstraintIRAnalysis::~ShapeConstraintIRAnalysis() { - // mgr_.Save(); -} +ShapeConstraintIRAnalysis::~ShapeConstraintIRAnalysis() {} bool ShapeConstraintIRAnalysis::IsShapeEqual(Value lhs, Value rhs) { if (lhs == rhs) return true; diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index 88b51f827581c..574821ab5b342 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -548,6 +548,14 @@ def _get_scope(self, program_id=None, use_scope_cache=False): @switch_to_static_graph def _create_program(self, is_infer_mode=False): if is_infer_mode: + # TODO(lanxianghit) mv this into pass_fn + def shape_pass_fn(forward_program, backward_program): + pm = paddle.base.libpaddle.pir.PassManager() + paddle.base.libpaddle.pir.infer_symbolic_shape_pass( + pm, forward_program + ) + pm.run(forward_program) + return forward_program, backward_program def pass_fn(forward_program, backward_program): pm = paddle.base.libpaddle.pir.PassManager() @@ -560,6 +568,7 @@ def pass_fn(forward_program, backward_program): infer_program = self.origin_runable_program.clone() if self._hooker: self._hooker.after_infer(infer_program) + infer_program.apply_pir_program_pass(shape_pass_fn) infer_program.apply_pir_program_pass(pass_fn) return infer_program else: diff --git a/test/cpp/pir/cinn/adt/map_expr_test.cc b/test/cpp/pir/cinn/adt/map_expr_test.cc index 578862495d1e4..14c90fbc80ded 100644 --- a/test/cpp/pir/cinn/adt/map_expr_test.cc +++ b/test/cpp/pir/cinn/adt/map_expr_test.cc @@ -74,6 +74,11 @@ TEST(MapExpr, ElementWise_Fusion_0) { ::pir::PassManager pass_manager(ctx); auto shape_analysis = std::make_shared(ctx); + + // TODO(@jiahy0825): use CreateShapeOptimizationPass() instead of + // CreateInferSymbolicShapePass() which is a fake pass + + /* pass_manager.AddPass(::pir::CreateInferSymbolicShapePass(shape_analysis)); pass_manager.Run(&program); @@ -112,4 +117,5 @@ MapExprTest(t_var_2, t_var_1) { } )TEST"; ASSERT_EQ(Trim(map_expr_str), Trim(target_str)); + */ } diff --git a/test/cpp/pir/shape_dialect/shape_optimization_test.cc b/test/cpp/pir/shape_dialect/shape_optimization_test.cc index 63621cce181df..fb32a6f234f15 100644 --- a/test/cpp/pir/shape_dialect/shape_optimization_test.cc +++ b/test/cpp/pir/shape_dialect/shape_optimization_test.cc @@ -43,10 +43,6 @@ TEST(shape_optimization, shape_optimization_pass) { // 5 ConstantOp + 5 TensorDim + 2 TieShape + op0 + op1 + 1 funcOp == 15 Ops. 
EXPECT_EQ(program.block()->size(), 2u); - - pir::SymbolicDimMgr mgr(program.module_op()); - EXPECT_TRUE(mgr.Load()); - EXPECT_TRUE(mgr.Save()); } TEST(shape_optimization, expand_shape_of_op_pattern) { @@ -69,10 +65,6 @@ TEST(shape_optimization, expand_shape_of_op_pattern) { pm.EnableIRPrinting(); pm.AddPass(pir::CreateShapeOptimizationPass()); pm.Run(&program); - - pir::SymbolicDimMgr mgr(program.module_op()); - EXPECT_TRUE(mgr.Load()); - EXPECT_TRUE(mgr.Save()); } TEST(shape_optimization, dim_of_shaped_type_op_interface_pattern) { @@ -100,8 +92,4 @@ TEST(shape_optimization, dim_of_shaped_type_op_interface_pattern) { pm.EnableIRPrinting(); pm.AddPass(pir::CreateShapeOptimizationPass()); pm.Run(&program); - - pir::SymbolicDimMgr mgr(program.module_op()); - EXPECT_TRUE(mgr.Load()); - EXPECT_TRUE(mgr.Save()); } diff --git a/test/cpp/pir/shape_dialect/shape_struct_test.cc b/test/cpp/pir/shape_dialect/shape_struct_test.cc index d2ed8a21b4e6c..12fbb641ba90c 100644 --- a/test/cpp/pir/shape_dialect/shape_struct_test.cc +++ b/test/cpp/pir/shape_dialect/shape_struct_test.cc @@ -86,427 +86,3 @@ TEST(shape_struct_test, symbolic_dim_mgr_simple) { EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_s1)); EXPECT_FALSE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_c10)); } - -TEST(shape_struct_test, symbolic_dim_mgr_complex) { - /***************************************************************/ - /* Mgr with constraintOp, and SymbolicDimProduct related func. */ - /***************************************************************/ - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - - pir::SymbolicDimMgr sym_dim_mgr(program.module_op()); - auto func_op = - sym_dim_mgr.symbolTable().getOp()->dyn_cast(); - - pir::Builder builder = pir::Builder(ctx, func_op.block()); - - pir::shape::SymbolicDimOp sym_dim_s0 = sym_dim_mgr.NewSymbolicDim("S0"); - pir::shape::SymbolicDimOp sym_dim_s1 = sym_dim_mgr.NewSymbolicDim("S1"); - pir::shape::SymbolicDimOp sym_dim_s2 = sym_dim_mgr.NewSymbolicDim("S2"); - pir::shape::SymbolicDimOp sym_dim_s3 = sym_dim_mgr.NewSymbolicDim("S3"); - pir::shape::SymbolicDimOp sym_dim_s4 = sym_dim_mgr.NewSymbolicDim("S4"); - pir::shape::SymbolicDimOp sym_dim_s5 = sym_dim_mgr.NewSymbolicDim("S5"); - pir::shape::SymbolicDimOp sym_dim_s6 = sym_dim_mgr.NewSymbolicDim("S6"); - pir::shape::SymbolicDimOp sym_dim_s7 = sym_dim_mgr.NewSymbolicDim("S7"); - pir::shape::SymbolicDimOp sym_dim_s8 = sym_dim_mgr.NewSymbolicDim("S8"); - pir::shape::SymbolicDimOp sym_dim_s9 = sym_dim_mgr.NewSymbolicDim("S9"); - pir::shape::SymbolicDimOp sym_dim_s10 = sym_dim_mgr.NewSymbolicDim("S10"); - pir::shape::SymbolicDimOp sym_dim_s11 = sym_dim_mgr.NewSymbolicDim("S11"); - pir::shape::SymbolicDimOp sym_dim_s12 = sym_dim_mgr.NewSymbolicDim("S12"); - pir::shape::SymbolicDimOp sym_dim_c10 = - sym_dim_mgr.NewConstantSymbolicDim(10); - pir::shape::SymbolicDimOp sym_dim_c20 = - sym_dim_mgr.NewConstantSymbolicDim(20); - - pir::OpResult dim_op_s0 = builder.Build("S0").out(); - pir::OpResult dim_op_s1 = builder.Build("S1").out(); - pir::OpResult dim_op_s2 = builder.Build("S2").out(); - pir::OpResult dim_op_s3 = builder.Build("S3").out(); - pir::OpResult dim_op_s4 = builder.Build("S4").out(); - pir::OpResult dim_op_s5 = builder.Build("S5").out(); - pir::OpResult dim_op_s6 = builder.Build("S6").out(); - pir::OpResult dim_op_s7 = builder.Build("S7").out(); - pir::OpResult dim_op_s8 = builder.Build("S8").out(); - pir::OpResult 
dim_op_s9 = builder.Build("S9").out(); - pir::OpResult dim_op_s10 = builder.Build("S10").out(); - pir::OpResult dim_op_s11 = builder.Build("S11").out(); - pir::OpResult dim_op_c10 = builder.Build("C10").out(); - pir::OpResult dim_op_c20 = builder.Build("C20").out(); - pir::OpResult constant = - builder - .Build(pir::Int32Attribute::get(ctx, 2), - pir::Int32Type::get(ctx)) - ->result(0); - - // Mark S1 == S2. - builder.Build( - 2, 2, std::vector{constant, dim_op_s1, dim_op_s2, constant}); - // Mark S0 * S1 == S2 * S3, For check S0 == S3. - builder.Build( - 2, - 2, - std::vector{dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3}); - // Mark S4 * S0 * S1 == S2 * S3 * S5, For check S4 == S5. - builder.Build( - 3, - 3, - std::vector{ - dim_op_s4, dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3, dim_op_s5}); - // For check S6 == C10 * C20. - builder.Build( - 1, 2, std::vector{dim_op_s6, dim_op_c10, dim_op_c20}); - // Mark C10 * S0 * S1 == S2 * S3 * S7, for check C10 == S7. - builder.Build( - 3, - 3, - std::vector{ - dim_op_c10, dim_op_s0, dim_op_s1, dim_op_s2, dim_op_s3, dim_op_s7}); - - // For unsimplify product case: S8 * S9 == S10 * S11 - builder.Build( - 2, - 2, - std::vector{dim_op_s8, dim_op_s9, dim_op_s10, dim_op_s11}); - - auto op = test::CreateDenseTensorOp( - ctx, {-1, -1, -1, -1, -1, -1}, {"op0_attr"}, {"op0_name"}); - auto op_ = test::CreateDenseTensorOp( - ctx, {-1, -1, -1, -1, -1, 10, 20}, {"op1_attr"}, {"op1_name"}); - pir::OpResult res = op->result(0); - pir::OpResult res_ = op_->result(0); - - builder.SetInsertionPointToBlockEnd(program.block()); - pir::shape::TieShapeOp tie_shape_op1 = - builder.Build(res); - pir::shape::TieShapeOp tie_shape_op2 = - builder.Build(res_); - - pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); - pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); - pir::Attribute attr_s2 = pir::StrAttribute::get(ctx, "S2"); - pir::Attribute attr_s3 = pir::StrAttribute::get(ctx, "S3"); - pir::Attribute attr_s4 = pir::StrAttribute::get(ctx, "S4"); - pir::Attribute attr_s5 = pir::StrAttribute::get(ctx, "S5"); - pir::Attribute attr_s6 = pir::StrAttribute::get(ctx, "S6"); - pir::Attribute attr_s7 = pir::StrAttribute::get(ctx, "S7"); - pir::Attribute attr_s8 = pir::StrAttribute::get(ctx, "S8"); - pir::Attribute attr_s9 = pir::StrAttribute::get(ctx, "S9"); - pir::Attribute attr_s10 = pir::StrAttribute::get(ctx, "S10"); - pir::Attribute attr_s11 = pir::StrAttribute::get(ctx, "S11"); - pir::Attribute attr_c10 = pir::StrAttribute::get(ctx, "C10"); - pir::Attribute attr_c20 = pir::StrAttribute::get(ctx, "C20"); - - std::vector new_attrs1 = { - attr_s0, attr_s1, attr_s2, attr_s3, attr_s4, attr_s5}; - std::vector new_attrs2 = {attr_s6, - attr_s7, - attr_s8, - attr_s9, - attr_s10, - attr_s11, - attr_c10, - attr_c20}; - std::vector new_attrs_ref = { - attr_s0, attr_s1, attr_s1, attr_s0, attr_s2, attr_s2}; - - auto array_attr1 = pir::ArrayAttribute::get(ctx, new_attrs1); - auto array_attr2 = pir::ArrayAttribute::get(ctx, new_attrs2); - auto array_attr_ref = pir::ArrayAttribute::get(ctx, new_attrs_ref); - - tie_shape_op1->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), array_attr1); - tie_shape_op2->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), array_attr2); - - EXPECT_TRUE(sym_dim_mgr.Load()); - - // For check indirect equality: S1 * S4 == S2 * S5 - pir::SymbolicDimProduct sym_dim_product_lhs1; - pir::SymbolicDimProduct sym_dim_product_rhs1; - - sym_dim_product_lhs1.symbols.push_back(sym_dim_s1); - 
sym_dim_product_lhs1.symbols.push_back(sym_dim_s4); - - sym_dim_product_rhs1.symbols.push_back(sym_dim_s2); - sym_dim_product_rhs1.symbols.push_back(sym_dim_s5); - - // For uncompletely simplied product check: S8 * S9 * S12 == S10 * S11 * S12 - pir::SymbolicDimProduct sym_dim_product_lhs2; - pir::SymbolicDimProduct sym_dim_product_rhs2; - - sym_dim_product_lhs2.symbols.push_back(sym_dim_s8); - sym_dim_product_lhs2.symbols.push_back(sym_dim_s9); - sym_dim_product_lhs2.symbols.push_back(sym_dim_s12); - - sym_dim_product_rhs2.symbols.push_back(sym_dim_s10); - sym_dim_product_rhs2.symbols.push_back(sym_dim_s11); - sym_dim_product_rhs2.symbols.push_back(sym_dim_s12); - - // For check SimplifySymbolicDimProduct, {factor = 1, Sym = {S7}} => {factor = - // 10} - pir::SymbolicDimProduct sym_dim_product_s7; - sym_dim_product_s7.symbols.push_back(sym_dim_s7); - pir::SymbolicDimProduct simplified_product_s7 = - sym_dim_mgr.SimplifySymbolicDimProduct(sym_dim_product_s7); - - // For check SimplifySymbolicDimProductPair, X * Y * Y, Y * Y * Z => X, Z - pir::SymbolicDimProduct sym_dim_product_pair_lhs; - pir::SymbolicDimProduct sym_dim_product_pair_rhs; - pir::SymbolicDimProduct new_lhs, new_rhs; - sym_dim_product_pair_lhs.symbols.push_back(sym_dim_s4); - sym_dim_product_pair_lhs.symbols.push_back(sym_dim_s1); - sym_dim_product_pair_lhs.symbols.push_back(sym_dim_s2); - sym_dim_product_pair_rhs.symbols.push_back(sym_dim_s1); - sym_dim_product_pair_rhs.symbols.push_back(sym_dim_s2); - sym_dim_product_pair_rhs.symbols.push_back(sym_dim_s3); - - std::tie(new_lhs, new_rhs) = sym_dim_mgr.SimplifySymbolicDimProductPair( - sym_dim_product_pair_lhs, sym_dim_product_pair_rhs); - - // For check SymbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor - // = 2 Sym = {S4}} - pir::SymbolicDimProduct sym_dim_product_div_lhs; - pir::SymbolicDimProduct sym_dim_product_div_rhs; - sym_dim_product_div_lhs.symbols.push_back(sym_dim_s4); - sym_dim_product_div_lhs.symbols.push_back(sym_dim_s1); - sym_dim_product_div_lhs.symbols.push_back(sym_dim_c20); - sym_dim_product_div_rhs.symbols.push_back(sym_dim_s1); - sym_dim_product_div_rhs.symbols.push_back(sym_dim_c10); - - pir::SymbolicDimProduct *divRes = sym_dim_mgr.SymbolicDimProductDivide( - sym_dim_product_div_lhs, sym_dim_product_div_rhs); - - EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s1, sym_dim_s2)); - EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s0, sym_dim_s3)); - EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimEqual(sym_dim_s4, sym_dim_s5)); - EXPECT_EQ(sym_dim_s6.GetDimSize(), 200); - EXPECT_EQ(sym_dim_mgr.symbolTable().Lookup("C20"), - sym_dim_c20); - EXPECT_EQ(sym_dim_s7.GetDimSize(), sym_dim_c10.GetDimSize()); - EXPECT_EQ(simplified_product_s7.factor, 10); - EXPECT_EQ(simplified_product_s7.symbols.size(), static_cast(0)); - EXPECT_EQ(new_lhs.symbols.size(), static_cast(1)); - EXPECT_EQ(new_rhs.symbols.size(), static_cast(1)); - EXPECT_EQ(new_lhs.symbols[0], sym_dim_mgr.GetRootSymbolicDim(sym_dim_s4)); - EXPECT_EQ(new_rhs.symbols[0], sym_dim_mgr.GetRootSymbolicDim(sym_dim_s3)); - EXPECT_EQ(divRes->factor, 2); - EXPECT_EQ(divRes->symbols.size(), static_cast(1)); - EXPECT_EQ(divRes->symbols[0], sym_dim_mgr.GetRootSymbolicDim(sym_dim_s4)); - EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimProductEqual(sym_dim_product_lhs1, - sym_dim_product_rhs1)); - EXPECT_TRUE(sym_dim_mgr.IsSymbolicDimProductEqual(sym_dim_product_lhs2, - sym_dim_product_rhs2)); - EXPECT_TRUE(sym_dim_mgr.Save()); - - pir::SymbolicDimMgr sym_dim_mgr_new(program.module_op()); - 
EXPECT_TRUE(sym_dim_mgr_new.Load()); - - auto attrs = tie_shape_op1.attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName()); - EXPECT_FALSE( - sym_dim_mgr_new.symbolTable().Lookup("S7")); - EXPECT_EQ(sym_dim_mgr_new.symbolTable() - .Lookup("tie_product_equal") - .size(), - static_cast(1)); - - EXPECT_EQ(attrs.AsVector(), array_attr_ref.AsVector()); -} - -TEST(shape_struct_test, shape_analysis) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - pir::shape::FuncOp func_op = builder.Build(); - - phi::DDim dims_D_2 = {-1, 2}; - phi::DDim dims_2_2 = {2, 2}; - phi::DDim dims_D = {-1}; - - // same shape with dynamic: value1 == value2 - auto op1 = - test::CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"}); - auto op2 = - test::CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"}); - pir::OpResult value1 = op1->result(0); - pir::OpResult value2 = op2->result(0); - - // same shape with static: value3 == value4 - auto op3 = - test::CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"}); - auto op4 = - test::CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"}); - pir::OpResult value3 = op3->result(0); - pir::OpResult value4 = op4->result(0); - - // one dimension with dynamic: value5 != value1 != value3 - auto op5 = test::CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"}); - pir::OpResult value5 = op5->result(0); - - pir::shape::TieShapeOp tie_shape_op1 = - builder.Build(value1); - pir::shape::TieShapeOp tie_shape_op2 = - builder.Build(value2); - pir::shape::TieShapeOp tie_shape_op3 = - builder.Build(value3); - pir::shape::TieShapeOp tie_shape_op4 = - builder.Build(value4); - pir::shape::TieShapeOp tie_shape_op5 = - builder.Build(value5); - - builder.SetInsertionPointToBlockEnd(func_op.block()); - builder.Build("C2", 2, true, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s0 = - builder.Build( - "S0", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s1 = - builder.Build( - "S1", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s2 = - builder.Build( - "S2", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - - pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); - pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); - pir::Attribute attr_s2 = pir::StrAttribute::get(ctx, "S2"); - pir::Attribute attr_c2 = pir::StrAttribute::get(ctx, "C2"); - - auto attr_op1 = pir::ArrayAttribute::get(ctx, {attr_s0, attr_c2}); - auto attr_op2 = pir::ArrayAttribute::get(ctx, {attr_s1, attr_c2}); - auto attr_op3 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2}); - auto attr_op4 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2}); - auto attr_op5 = pir::ArrayAttribute::get(ctx, {attr_s2}); - - tie_shape_op1->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op1); - tie_shape_op2->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op2); - tie_shape_op3->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op3); - tie_shape_op4->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op4); - tie_shape_op5->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op5); - - pir::ShapeConstraintIRAnalysis shape_analysis(program.module_op()); - 
EXPECT_TRUE(shape_analysis.IsShapeEqual(value3, value4)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value2)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value3)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value3, value5)); - EXPECT_TRUE(shape_analysis.IsProductEqual(value1, {1}, value3, {0})); - EXPECT_TRUE(shape_analysis.IsSameNumElements(value4, value3)); - - shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s1); - shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s2); - - const auto &val_sym_dim1 = - shape_analysis.GetOrCreateSymbolicDimsForRankedValue(value1); - const auto &val_sym_dim2 = - shape_analysis.GetOrCreateSymbolicDimsForRankedValue(value2); - EXPECT_TRUE(shape_analysis.symbolicDimMgr().IsSymbolicDimEqual( - val_sym_dim1[0], val_sym_dim2[0])); - - EXPECT_TRUE(shape_analysis.IsShapeEqual(value1, value2)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5)); -} - -TEST(shape_struct_test, shape_analysis_manager) { - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); - ctx->GetOrRegisterDialect(); - ctx->GetOrRegisterDialect(); - ::pir::Builder builder = ::pir::Builder(ctx, program.block()); - pir::shape::FuncOp func_op = builder.Build(); - - phi::DDim dims_D_2 = {-1, 2}; - phi::DDim dims_2_2 = {2, 2}; - phi::DDim dims_D = {-1}; - - // same shape with dynamic: value1 == value2 - auto op1 = - test::CreateDenseTensorOp(ctx, dims_D_2, {"op1_attr"}, {"op1_name"}); - auto op2 = - test::CreateDenseTensorOp(ctx, dims_D_2, {"op2_attr"}, {"op2_name"}); - pir::OpResult value1 = op1->result(0); - pir::OpResult value2 = op2->result(0); - - // same shape with static: value3 == value4 - auto op3 = - test::CreateDenseTensorOp(ctx, dims_2_2, {"op3_attr"}, {"op3_name"}); - auto op4 = - test::CreateDenseTensorOp(ctx, dims_2_2, {"op4_attr"}, {"op4_name"}); - pir::OpResult value3 = op3->result(0); - pir::OpResult value4 = op4->result(0); - - // one dimension with dynamic: value5 != value1 != value3 - auto op5 = test::CreateDenseTensorOp(ctx, dims_D, {"op5_attr"}, {"op5_name"}); - pir::OpResult value5 = op5->result(0); - - pir::shape::TieShapeOp tie_shape_op1 = - builder.Build(value1); - pir::shape::TieShapeOp tie_shape_op2 = - builder.Build(value2); - pir::shape::TieShapeOp tie_shape_op3 = - builder.Build(value3); - pir::shape::TieShapeOp tie_shape_op4 = - builder.Build(value4); - pir::shape::TieShapeOp tie_shape_op5 = - builder.Build(value5); - - builder.SetInsertionPointToBlockEnd(func_op.block()); - builder.Build("C2", 2, true, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s0 = - builder.Build( - "S0", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s1 = - builder.Build( - "S1", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - pir::shape::SymbolicDimOp sym_dim_s2 = - builder.Build( - "S2", pir::ShapedTypeInterface::kDynamic, false, false, true, true); - - pir::Attribute attr_s0 = pir::StrAttribute::get(ctx, "S0"); - pir::Attribute attr_s1 = pir::StrAttribute::get(ctx, "S1"); - pir::Attribute attr_s2 = pir::StrAttribute::get(ctx, "S2"); - pir::Attribute attr_c2 = pir::StrAttribute::get(ctx, "C2"); - - auto attr_op1 = pir::ArrayAttribute::get(ctx, {attr_s0, attr_c2}); - auto attr_op2 = pir::ArrayAttribute::get(ctx, {attr_s1, attr_c2}); - auto attr_op3 = pir::ArrayAttribute::get(ctx, {attr_c2, attr_c2}); - auto attr_op4 = pir::ArrayAttribute::get(ctx, {attr_c2, 
attr_c2}); - auto attr_op5 = pir::ArrayAttribute::get(ctx, {attr_s2}); - - tie_shape_op1->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op1); - tie_shape_op2->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op2); - tie_shape_op3->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op3); - tie_shape_op4->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op4); - tie_shape_op5->set_attribute( - pir::shape::SymbolicDimOp::GetSymbolicDimAttrName(), attr_op5); - - auto shape_analysis_mgr = pir::ShapeAnalysisManager::Instance(); - pir::ShapeConstraintIRAnalysis &shape_analysis = - shape_analysis_mgr.Get(&program); - - EXPECT_TRUE(shape_analysis.IsShapeEqual(value3, value4)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value2)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value3)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value3, value5)); - EXPECT_TRUE(shape_analysis.IsProductEqual(value1, {1}, value3, {0})); - EXPECT_TRUE(shape_analysis.IsSameNumElements(value4, value3)); - - shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s1); - shape_analysis.symbolicDimMgr().MapSymbolicDimEqual(sym_dim_s0, sym_dim_s2); - - EXPECT_TRUE(shape_analysis.IsShapeEqual(value1, value2)); - EXPECT_FALSE(shape_analysis.IsShapeEqual(value1, value5)); -}
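
The patch above threads value "data" (the concrete DimExprs held by shape-producing tensors) through symbolic shape inference: ShapeOrData gains a public (shape, data) constructor plus SetData, pd_op.full_int_array records its attribute values as data, and the new SliceOpInferSymbolicShape cuts that data using the constant start/end bounds before storing the result in shape_analysis->value_id_to_shapeordata_. The standalone C++ sketch below uses simplified stand-in types (not the real symbol::DimExpr or pir API) to illustrate that shape-plus-data bookkeeping; names such as InferSliceOnData are hypothetical.

// Standalone sketch with hypothetical types; illustrates the shape-plus-data
// propagation added by the patch, not the actual pir/symbol implementation.
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Simplified stand-in for symbol::DimExpr: either a constant or a symbol name.
using DimExpr = std::variant<std::int64_t, std::string>;

class ShapeOrData {
 public:
  explicit ShapeOrData(std::vector<DimExpr> shape)
      : shape_(std::move(shape)), data_(std::nullopt) {}
  // Mirrors the now-public (shape, data) constructor in dim_expr.h.
  ShapeOrData(std::vector<DimExpr> shape, std::vector<DimExpr> data)
      : shape_(std::move(shape)), data_(std::move(data)) {}

  const std::vector<DimExpr>& shape() const { return shape_; }
  const std::optional<std::vector<DimExpr>>& data() const { return data_; }
  void SetData(std::vector<DimExpr> data) { data_ = std::move(data); }

 private:
  std::vector<DimExpr> shape_;
  std::optional<std::vector<DimExpr>> data_;
};

// Mimics the data handling in the patched SliceOpInferSymbolicShape: copy the
// source's elements in [start, end) and attach a one-dimensional result shape.
ShapeOrData InferSliceOnData(const ShapeOrData& source,
                             std::int64_t start,
                             std::int64_t end) {
  std::vector<DimExpr> data;
  if (source.data().has_value()) {
    for (std::int64_t i = start; i < end; ++i) {
      data.push_back(source.data()->at(static_cast<std::size_t>(i)));
    }
  }
  return ShapeOrData({DimExpr{end - start}}, data);
}

int main() {
  // A "shape tensor" whose elements are [S0, 32, S1, 4]; its own shape is [4].
  ShapeOrData shape_tensor(
      {DimExpr{std::int64_t{4}}},
      {DimExpr{std::string{"S0"}}, DimExpr{std::int64_t{32}},
       DimExpr{std::string{"S1"}}, DimExpr{std::int64_t{4}}});

  // Slicing elements [1, 3) keeps {32, S1} as data, with result shape [2].
  ShapeOrData sliced = InferSliceOnData(shape_tensor, 1, 3);
  for (const DimExpr& e : *sliced.data()) {
    if (std::holds_alternative<std::int64_t>(e)) {
      std::cout << std::get<std::int64_t>(e) << " ";
    } else {
      std::cout << std::get<std::string>(e) << " ";
    }
  }
  std::cout << std::endl;  // prints: 32 S1
  return 0;
}

In the real pass the analogous result is keyed by the output value's id in shape_analysis->value_id_to_shapeordata_, so downstream ops (for example the reshape and combine handlers changed above) can read both the symbolic shape and the element values of their operands.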