From 6e3ea63672ba3f342df3d14a587cb04e19c4dc2e Mon Sep 17 00:00:00 2001 From: zhangbopd <1299246947@qq.com> Date: Thu, 28 Dec 2023 07:37:18 +0000 Subject: [PATCH 1/4] add full_int_array --- paddle/fluid/inference/paddle_inference.map | 1 - .../interface/infer_symbolic_shape.cc | 35 ++++++++ .../operator/interface/infer_symbolic_shape.h | 3 + .../pir/transforms/shape_optimization_pass.cc | 79 ++----------------- paddle/phi/api/yaml/ops.yaml | 1 + 5 files changed, 45 insertions(+), 74 deletions(-) diff --git a/paddle/fluid/inference/paddle_inference.map b/paddle/fluid/inference/paddle_inference.map index 29f131be85e1a..01a989cc568bc 100644 --- a/paddle/fluid/inference/paddle_inference.map +++ b/paddle/fluid/inference/paddle_inference.map @@ -82,7 +82,6 @@ *Pass*; *profile*; *phi*; - *pir*; PD_*; *cinn*; local: diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 1b9ca43b7d9f1..db55d3f048cfa 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -158,6 +158,41 @@ bool Reshape_OpInferSymbolicShape( return ReshapeOpInferSymbolicShape(op, shape_analysis); } +bool FullIntArrayOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + for (auto &res : op->results()) { + std::string value_id = pir::GetValueId(&res); + std::vector dims = + common::vectorize(res.type().dyn_cast().dims()); + + std::vector shapes; + for (int64_t dim : dims) { + symbol::DimExpr dim_expr; + if (dim == -1) { + symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); + dim_expr = res_dim_expr; + } else { + symbol::DimExpr res_dim_expr(dim); + dim_expr = res_dim_expr; + } + shapes.push_back(dim_expr); + } + + auto attributes = op->attributes(); + pir::Attribute attr = attributes["value"]; + const auto &vec = attr.dyn_cast().AsVector(); + + for (auto item : vec) { + int64_t i = item.dyn_cast().data(); + shapes.push_back(symbol::DimExpr(i)); + } + + symbol::ShapeOrDataDimExprs shape_data{shapes}; + shape_analysis->value_id_to_shapeordata_[value_id] = shape_data; + return true; + } +} + } // namespace paddle::dialect namespace cinn::dialect { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h index b1c72e3111df2..d9558cef89356 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h @@ -101,6 +101,9 @@ bool ReshapeOpInferSymbolicShape( bool Reshape_OpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); +bool FullIntArrayOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); + } // namespace paddle::dialect namespace cinn::dialect { diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index 5c6481110034e..2dcd36664bc54 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -237,20 +237,6 @@ struct DimOfShapedTypeOpInterfacePattern : public OpRewritePattern { } }; -bool MaterializeShapeComputation(pir::ModuleOp m) { - // if (!InsertTieShapeOnRegion(&(m->region(0)))) return false; - // TODO(zhangbopd): add rewitter pattern for reifyInferShape. - RewritePatternSet patterns(m.ir_context()); - - patterns.Add>( - patterns.ir_context()); - - IR_ENFORCE(ApplyPatternsGreedily(m, std::move(patterns)).first, - "fail to materialize shape computation\n"); - return true; -} - using PassPipelineRunner = std::function; @@ -443,20 +429,6 @@ bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { return true; } -bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { - // TODO(zhangbopd): Do some Canonicalizer. - pir::SymbolicDimMgr mgr(m); - IR_ENFORCE(mgr.Load(), - "SymbolicDimMgr Load failed in OptimizeShapeComputation."); - ShapeComputationIRAnalysis analysis(m, mgr); - if (!analysis.Run()) { - return false; - } - IR_ENFORCE(mgr.Save(), - "SymbolicDimMgr save failed in OptimizeShapeComputation."); - return true; -} - void PrintProgram(pir::ModuleOp m, std::string mgs) { std::ostringstream print_stream; print_stream << "\n\n"; @@ -514,48 +486,12 @@ void InferSymExprForAllValues(ModuleOp module_op) { for (int i = 0; i < module_op->num_regions(); i++) { for (auto& block : module_op->region(i)) { for (auto& op : block) { - if (op.num_operands() == 0) { - for (auto& res : op.results()) { - auto value_id = pir::GetValueId(&res); - - std::vector dims = common::vectorize( - res.type().dyn_cast().dims()); - - std::vector shapes; - for (int64_t dim : dims) { - symbol::DimExpr dim_expr; - if (dim == -1) { - symbol::DimExpr res_dim_expr(shape_analysis.GetNextSymName()); - dim_expr = res_dim_expr; - } else { - symbol::DimExpr res_dim_expr(dim); - dim_expr = res_dim_expr; - } - shapes.push_back(dim_expr); - } - - if (op.name() == "pd_op.full_int_array") { - auto attributes = op.attributes(); - auto attr = attributes["value"]; - auto arr = attr.dyn_cast(); - const auto& vec = arr.AsVector(); - for (auto item : vec) { - int64_t i = item.dyn_cast().data(); - shapes.push_back(symbol::DimExpr(i)); - } - } - symbol::ShapeOrDataDimExprs shape_data{shapes}; - shape_analysis.value_id_to_shapeordata_[value_id] = shape_data; - } - } else { - auto infer_symbolic_shape_interface = - op.dyn_cast(); - if (infer_symbolic_shape_interface) { - PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape( - &shape_analysis)); - } + auto infer_symbolic_shape_interface = + op.dyn_cast(); + if (infer_symbolic_shape_interface) { + PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape( + &shape_analysis)); } - DebugPrintOpInfo(&op, &shape_analysis); } } @@ -574,14 +510,11 @@ class ShapeOptimizationPass : public pir::Pass { PrintProgram(module_op, "Origin Program"); InferSymExprForAllValues(module_op); - MaterializeShapeComputation(module_op); + // Runner is for Canonicalizer. PassPipelineRunner runner = [this](pir::PassManager& pm, pir::ModuleOp m) { return pm.Run(m.program()); }; - // if (!OptimizeShapeComputation(module_op, runner)) { - // return; - // } VLOG(3) << "===================== ShapeOptimizationPass Run End. " "============================="; } diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index de4d700cdf80e..4cdb38f39e5fe 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1041,6 +1041,7 @@ param : [value, dtype] data_type : dtype backend : place + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : gather args : (Tensor x, Tensor index, Scalar axis=0) From 75fe3aed3c7526307fa887324c38d55805185c2e Mon Sep 17 00:00:00 2001 From: zhangbopd <1299246947@qq.com> Date: Fri, 29 Dec 2023 11:59:40 +0000 Subject: [PATCH 2/4] add SymbolAttribute --- paddle/fluid/inference/utils/CMakeLists.txt | 2 +- .../interface/infer_symbolic_shape.cc | 20 +++++- .../operator/interface/infer_symbolic_shape.h | 3 + .../pir/transforms/shape_optimization_pass.cc | 11 +-- paddle/phi/api/yaml/ops.yaml | 1 + .../pir/dialect/shape/ir/shape_attribute.cc | 30 ++++++++ paddle/pir/dialect/shape/ir/shape_attribute.h | 37 ++++++++++ .../shape/ir/shape_attribute_storage.h | 70 +++++++++++++++++++ paddle/pir/dialect/shape/ir/shape_dialect.cc | 13 ++++ paddle/pir/dialect/shape/ir/shape_dialect.h | 4 ++ .../symbolic/test_cinn_sub_graph_symbolic.py | 3 +- 11 files changed, 187 insertions(+), 7 deletions(-) create mode 100644 paddle/pir/dialect/shape/ir/shape_attribute.cc create mode 100644 paddle/pir/dialect/shape/ir/shape_attribute.h create mode 100644 paddle/pir/dialect/shape/ir/shape_attribute_storage.h diff --git a/paddle/fluid/inference/utils/CMakeLists.txt b/paddle/fluid/inference/utils/CMakeLists.txt index 3dbc06bfc11b7..976cb2dccc8c1 100644 --- a/paddle/fluid/inference/utils/CMakeLists.txt +++ b/paddle/fluid/inference/utils/CMakeLists.txt @@ -13,7 +13,7 @@ cc_library( DEPS proto_desc enforce common) cc_library(table_printer SRCS table_printer.cc) -paddle_test(test_table_printer SRCS table_printer_tester.cc) +paddle_test(test_table_printer SRCS table_printer_tester.cc DEPS pir) proto_library(shape_range_info_proto SRCS shape_range_info.proto) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index db55d3f048cfa..2e477a5949dc9 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -15,7 +15,7 @@ #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_type.h" -#include "paddle/pir/dialect/shape/ir/shape_op.h" +#include "paddle/pir/dialect/shape/ir/shape_attribute.h" namespace paddle::dialect { @@ -63,6 +63,20 @@ bool Abs_OpInferSymbolicShape(pir::Operation *op, return InferSymbolicShapeAllEqualUnary(op, shape_analysis); } +bool DataOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + symbol::ShapeOrDataDimExprs sss; + + op->set_attribute( + "sym_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), sss)); + + // auto attributes = op->attributes(); + // pir::Attribute attr = attributes["shape"]; + // const auto &vec = attr.dyn_cast().AsVector(); + return true; +} + bool CastOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { return InferSymbolicShapeAllEqualUnary(op, shape_analysis); @@ -187,6 +201,10 @@ bool FullIntArrayOpInferSymbolicShape( shapes.push_back(symbol::DimExpr(i)); } + // for (auto &item : shapes) { + // VLOG(0) << symbol::ToString(item); + // } + symbol::ShapeOrDataDimExprs shape_data{shapes}; shape_analysis->value_id_to_shapeordata_[value_id] = shape_data; return true; diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h index d9558cef89356..75d078986fb55 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h @@ -68,6 +68,9 @@ bool AbsOpInferSymbolicShape(pir::Operation *op, bool Abs_OpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); +bool DataOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + bool CastOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index 2dcd36664bc54..4d8408c304df0 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -441,15 +441,16 @@ void PrintProgram(pir::ModuleOp m, std::string mgs) { void DebugPrintOpInfo( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { - VLOG(0) << op->name() << ", num_operands: " << op->num_operands(); for (auto& res : op->results()) { auto value_id = pir::GetValueId(&res); std::ostringstream print_stream; - print_stream << ">>>> result(" << res.index() << ") 's ID: " << value_id; + print_stream << "result(" << res.index() << ") " + << "ShapeOrData: "; + if (shape_analysis != nullptr) { auto shape_data = shape_analysis->value_id_to_shapeordata_[value_id]; - print_stream << ", ShapeOrData.shape: ["; + print_stream << "shape: ["; for (auto str : shape_data.shape()) { int64_t* i = std::get_if(&str); @@ -461,7 +462,7 @@ void DebugPrintOpInfo( } } - print_stream << "], ShapeOrData.data: ["; + print_stream << "], data: ["; if (shape_data.data().has_value()) { for (auto str : shape_data.data().value()) { int64_t* i = std::get_if(&str); @@ -489,6 +490,7 @@ void InferSymExprForAllValues(ModuleOp module_op) { auto infer_symbolic_shape_interface = op.dyn_cast(); if (infer_symbolic_shape_interface) { + VLOG(0) << op.name() << " has InferSymbolicShapeInterface."; PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape( &shape_analysis)); } @@ -517,6 +519,7 @@ class ShapeOptimizationPass : public pir::Pass { }; VLOG(3) << "===================== ShapeOptimizationPass Run End. " "============================="; + PrintProgram(module_op, "ShapeOptimizationPass Program"); } bool CanApplyOn(pir::Operation* op) const override { diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 4cdb38f39e5fe..68ba253792d8a 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -664,6 +664,7 @@ param : [name, shape, dtype] data_type : dtype backend : place + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : depthwise_conv2d args : (Tensor input, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") diff --git a/paddle/pir/dialect/shape/ir/shape_attribute.cc b/paddle/pir/dialect/shape/ir/shape_attribute.cc new file mode 100644 index 0000000000000..c8751f0433ee1 --- /dev/null +++ b/paddle/pir/dialect/shape/ir/shape_attribute.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/dialect/shape/ir/shape_attribute.h" + +namespace pir::shape { + +symbol::ShapeOrDataDimExprs SymbolAttribute::data() const { + return storage()->data(); +} + +SymbolAttribute SymbolAttribute::get(pir::IrContext* ctx, + const symbol::ShapeOrDataDimExprs& value) { + return AttributeManager::get(ctx, value); +} + +} // namespace pir::shape + +IR_DEFINE_EXPLICIT_TYPE_ID(pir::shape::SymbolAttribute) diff --git a/paddle/pir/dialect/shape/ir/shape_attribute.h b/paddle/pir/dialect/shape/ir/shape_attribute.h new file mode 100644 index 0000000000000..1eda1ab35f1a7 --- /dev/null +++ b/paddle/pir/dialect/shape/ir/shape_attribute.h @@ -0,0 +1,37 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/utils.h" +#include "paddle/pir/dialect/shape/ir/shape_attribute_storage.h" + +namespace pir::shape { + +class IR_API SymbolAttribute : public Attribute { + public: + using Attribute::Attribute; + + DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(SymbolAttribute, SymbolAttributeStorage); + + symbol::ShapeOrDataDimExprs data() const; + + static SymbolAttribute get(IrContext* ctx, + const symbol::ShapeOrDataDimExprs& value); +}; + +} // namespace pir::shape + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::shape::SymbolAttribute) diff --git a/paddle/pir/dialect/shape/ir/shape_attribute_storage.h b/paddle/pir/dialect/shape/ir/shape_attribute_storage.h new file mode 100644 index 0000000000000..11333f6b0d3e2 --- /dev/null +++ b/paddle/pir/dialect/shape/ir/shape_attribute_storage.h @@ -0,0 +1,70 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include "paddle/common/enforce.h" +#include "paddle/pir/core/attribute.h" +#include "paddle/pir/core/attribute_base.h" +#include "paddle/pir/core/type.h" +#include "paddle/pir/core/utils.h" +#include "paddle/pir/dialect/shape/utils/dim_expr.h" + +namespace pir::shape { + +/// +/// \brief Define Parametric AttributeStorage for SymbolAttribute. +/// +struct SymbolAttributeStorage : public AttributeStorage { + using ParamKey = symbol::ShapeOrDataDimExprs; + + explicit SymbolAttributeStorage(const ParamKey &key) : data_(key) {} + + static SymbolAttributeStorage *Construct(const ParamKey &key) { + return new SymbolAttributeStorage(key); + } + + static std::size_t HashValue(const ParamKey &key) { + std::size_t hash_value = 0; + for (size_t i = 0; i < key.shape().size(); ++i) { + hash_value = hash_combine( + hash_value, + std::hash()(symbol::ToString(key.shape()[i]))); + } + if (key.data().has_value()) { + for (size_t i = 0; i < key.data().value().size(); ++i) { + hash_value = hash_combine( + hash_value, + std::hash()(symbol::ToString(key.data().value()[i]))); + } + } + + return hash_value; + } + + bool operator==(const ParamKey &key) const { + return data_.shape() == key.shape() && data_.data() == key.data(); + } + + ParamKey data() const { return data_; } + + private: + ParamKey data_; +}; + +} // namespace pir::shape diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.cc b/paddle/pir/dialect/shape/ir/shape_dialect.cc index 0353a7610d2b3..f730256d61beb 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.cc +++ b/paddle/pir/dialect/shape/ir/shape_dialect.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/pir/dialect/shape/ir/shape_dialect.h" +#include "paddle/pir/dialect/shape/ir/shape_attribute.h" #include "paddle/pir/dialect/shape/ir/shape_op.h" namespace pir::shape { @@ -33,6 +34,18 @@ void ShapeDialect::initialize() { ExtractOp, ConstantOp, IndexCastOp>(); + + RegisterAttributes(); +} + +void ShapeDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const { + if (attr.isa()) { + SymbolAttribute symbol_attr = attr.dyn_cast(); + os << "(shape_or_data)"; + for (size_t i = 0; i < symbol_attr.data().shape().size(); ++i) { + os << symbol::ToString(symbol_attr.data().shape()[i]); + } + } } void ShapeDialect::PrintOperation(Operation *op, IrPrinter &printer) const { diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.h b/paddle/pir/dialect/shape/ir/shape_dialect.h index 4be71aa0127ce..33b7419c251dd 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.h +++ b/paddle/pir/dialect/shape/ir/shape_dialect.h @@ -23,7 +23,11 @@ namespace pir::shape { class IR_API ShapeDialect : public Dialect { public: explicit ShapeDialect(IrContext* context); + static const char* name() { return "shape"; } + + void PrintAttribute(pir::Attribute type, std::ostream& os) const override; + void PrintOperation(Operation* op, IrPrinter& printer) const override; // NOLINT diff --git a/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py b/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py index 98968a18e228f..888e4de0eea3e 100644 --- a/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py +++ b/test/ir/pir/cinn/symbolic/test_cinn_sub_graph_symbolic.py @@ -117,9 +117,10 @@ def test_eval_symolic(self): import os is_debug = os.getenv('IS_DEBUG_DY_SHAPE') + is_debug = True if is_debug: cinn_out = self.eval_symbolic(use_cinn=True) - # print("cinn_out:", cinn_out) + print("cinn_out:", cinn_out) # dy_out = self.eval_symbolic(use_cinn=False) # print("dy_out:", dy_out) From 0252315956778cd39219a87ef7c4c26e3568e1b6 Mon Sep 17 00:00:00 2001 From: zhangbopd <1299246947@qq.com> Date: Tue, 2 Jan 2024 08:26:01 +0000 Subject: [PATCH 3/4] add slice --- paddle/cinn/hlir/framework/pir/utils.cc | 3 + .../interface/infer_symbolic_shape.cc | 253 +++++++++--------- .../operator/interface/infer_symbolic_shape.h | 3 + .../pir/dialect/operator/ir/op_dialect.cc | 48 +++- paddle/fluid/pir/dialect/operator/ir/ops.yaml | 1 + .../pir/transforms/shape_optimization_pass.cc | 41 +-- paddle/pir/dialect/shape/ir/shape_dialect.cc | 23 +- paddle/pir/dialect/shape/utils/dim_expr.h | 20 +- paddle/pir/dialect/shape/utils/shape_utils.h | 3 + 9 files changed, 239 insertions(+), 156 deletions(-) diff --git a/paddle/cinn/hlir/framework/pir/utils.cc b/paddle/cinn/hlir/framework/pir/utils.cc index 8833ac496e32c..a0a6f5f15614b 100644 --- a/paddle/cinn/hlir/framework/pir/utils.cc +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -25,6 +25,7 @@ #include "paddle/phi/common/data_type.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/dialect/shape/ir/shape_attribute.h" namespace cinn { namespace hlir { @@ -146,6 +147,8 @@ utils::Attribute CompatibleInfo::ConvertAttribute( } else if (src_attr.isa()) { auto dtype = src_attr.dyn_cast().data(); dst_attr = phi::DataTypeToString(dtype); + } else if (src_attr.isa<::pir::shape::SymbolAttribute>()) { + auto dst_attr = src_attr.dyn_cast<::pir::shape::SymbolAttribute>().data(); } else if (src_attr.isa<::pir::ArrayAttribute>()) { auto attr_vec = src_attr.dyn_cast<::pir::ArrayAttribute>().AsVector(); if (attr_vec.size() > 0) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc index 2e477a5949dc9..4a6bf1c87d383 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_type.h" #include "paddle/pir/dialect/shape/ir/shape_attribute.h" @@ -25,113 +26,113 @@ bool InferSymbolicShapeInterface::InferSymbolicShape( } } // namespace paddle::dialect -namespace paddle::dialect { - namespace { -bool InferSymbolicShapeAllEqualUnary( +bool SameOperandsAndResultShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); - pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); - shape_analysis->value_id_to_shapeordata_[res_id] = - shape_analysis->value_id_to_shapeordata_[operand_source_id]; - return true; -} -bool InferSymbolicShapeAllEqualBinary( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); + symbol::ShapeOrDataDimExprs operand_shape_or_data = + shape_analysis->value_to_shape_or_data_[operand_source]; + + op->set_attribute("symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), + operand_shape_or_data)); pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); - shape_analysis->value_id_to_shapeordata_[res_id] = - shape_analysis->value_id_to_shapeordata_[operand_source_id]; + shape_analysis->value_to_shape_or_data_[res] = operand_shape_or_data; return true; } } // namespace +namespace paddle::dialect { bool AbsOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool Abs_OpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool DataOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - symbol::ShapeOrDataDimExprs sss; + auto attributes = op->attributes(); + pir::Attribute attr = attributes["shape"]; + std::vector dims = + attr.dyn_cast().data().GetData(); + std::vector sym_dims; + for (auto dim : dims) { + symbol::DimExpr dim_expr; + if (dim == -1) { + symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName()); + dim_expr = symbolic_dim_expr; + } else { + symbol::DimExpr numeric_dim_expr(dim); + dim_expr = numeric_dim_expr; + } + sym_dims.push_back(dim_expr); + } + + symbol::ShapeOrDataDimExprs shape_data{sym_dims}; op->set_attribute( - "sym_shape", - pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), sss)); + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); + + pir::OpResult res = op->result(0); + shape_analysis->value_to_shape_or_data_[res] = shape_data; - // auto attributes = op->attributes(); - // pir::Attribute attr = attributes["shape"]; - // const auto &vec = attr.dyn_cast().AsVector(); return true; } bool CastOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool Cast_OpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool ExpOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool Exp_OpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualUnary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool SubtractOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualBinary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool Subtract_OpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeAllEqualBinary(op, shape_analysis); + return SameOperandsAndResultShape(op, shape_analysis); } bool ShapeOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); - std::vector dims = - common::vectorize(res.type().dyn_cast().dims()); + symbol::ShapeOrDataDimExprs operand_shape_or_data = + shape_analysis->value_to_shape_or_data_[operand_source]; - std::vector shapes; - for (int64_t dim : dims) { - symbol::DimExpr dim_expr; - if (dim == -1) { - symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); - dim_expr = res_dim_expr; - } else { - symbol::DimExpr res_dim_expr(dim); - dim_expr = res_dim_expr; - } - shapes.push_back(dim_expr); - } + symbol::ShapeOrDataDimExprs extend_shape_or_data = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData( + operand_shape_or_data); - symbol::ShapeOrDataDimExprs shape_data{shapes}; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + shape_analysis->value_to_shape_or_data_[res] = extend_shape_or_data; + op->set_attribute("symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), + extend_shape_or_data)); return true; } @@ -143,27 +144,53 @@ bool ShapeSrOpInferSymbolicShape( bool StackOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); - pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); + symbol::ShapeOrDataDimExprs operand_shape_or_data = + shape_analysis->value_to_shape_or_data_[operand_source]; - symbol::ShapeOrDataDimExprs shape_data; - shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_id]; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + std::vector out_dims; + if (operand_shape_or_data.data().has_value()) { + out_dims = operand_shape_or_data.data().value(); + } + // else : pir::VectorType x = + // operand_source.type().dyn_cast(); + // TODO(zhangbopd): else branch is not implemented yet. + + symbol::ShapeOrDataDimExprs shape_data{out_dims}; + if (operand_shape_or_data.data().has_value()) { + shape_data = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(shape_data); + } + + op->set_attribute( + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); + pir::OpResult res = op->result(0); + shape_analysis->value_to_shape_or_data_[res] = shape_data; return true; } bool ReshapeOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_source_1 = op->operand_source(1); - std::string operand_source_1_id = pir::GetValueId(&operand_source_1); - pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); + pir::Value operand_source_shape = op->operand_source(1); - symbol::ShapeOrDataDimExprs shape_data; + symbol::ShapeOrDataDimExprs operand_shape_or_data = + shape_analysis->value_to_shape_or_data_[operand_source_shape]; + + std::vector out_dims; + if (operand_shape_or_data.data().has_value()) { + out_dims = operand_shape_or_data.data().value(); + } - shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id]; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + symbol::ShapeOrDataDimExprs shape_data{out_dims}; + op->set_attribute( + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); + + pir::OpResult res0 = op->result(0); + pir::OpResult res1 = op->result(1); + shape_analysis->value_to_shape_or_data_[res0] = shape_data; + shape_analysis->value_to_shape_or_data_[res1] = + shape_analysis->value_to_shape_or_data_[operand_source_shape]; return true; } @@ -174,41 +201,32 @@ bool Reshape_OpInferSymbolicShape( bool FullIntArrayOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - for (auto &res : op->results()) { - std::string value_id = pir::GetValueId(&res); - std::vector dims = - common::vectorize(res.type().dyn_cast().dims()); - - std::vector shapes; - for (int64_t dim : dims) { - symbol::DimExpr dim_expr; - if (dim == -1) { - symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); - dim_expr = res_dim_expr; - } else { - symbol::DimExpr res_dim_expr(dim); - dim_expr = res_dim_expr; - } - shapes.push_back(dim_expr); - } + auto attributes = op->attributes(); + pir::Attribute attr = attributes["value"]; + const auto &vec = attr.dyn_cast().AsVector(); + + std::vector data; + for (auto item : vec) { + int64_t i = item.dyn_cast().data(); + data.push_back(symbol::DimExpr(i)); + } - auto attributes = op->attributes(); - pir::Attribute attr = attributes["value"]; - const auto &vec = attr.dyn_cast().AsVector(); + symbol::ShapeOrDataDimExprs shape_data = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(data); - for (auto item : vec) { - int64_t i = item.dyn_cast().data(); - shapes.push_back(symbol::DimExpr(i)); - } + op->set_attribute( + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - // for (auto &item : shapes) { - // VLOG(0) << symbol::ToString(item); - // } + pir::OpResult res = op->result(0); + shape_analysis->value_to_shape_or_data_[res] = shape_data; + return true; +} - symbol::ShapeOrDataDimExprs shape_data{shapes}; - shape_analysis->value_id_to_shapeordata_[value_id] = shape_data; - return true; - } +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + // TODO(zhangbopd): Not implemented yet. + return true; } } // namespace paddle::dialect @@ -216,39 +234,34 @@ namespace cinn::dialect { bool SliceOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + // TODO(zhangbopd): Not implemented yet, different from the one in paddle + // dialect. pir::Value operand_source = op->operand_source(0); - std::string operand_source_id = pir::GetValueId(&operand_source); - pir::OpResult res = op->result(0); - std::string res_id = pir::GetValueId(&res); + symbol::ShapeOrDataDimExprs operand_shape_or_data = + shape_analysis->value_to_shape_or_data_[operand_source]; + pir::AttributeMap attributes = op->attributes(); - std::vector dims = - common::vectorize(res.type().dyn_cast().dims()); - - std::vector shapes; - for (int64_t dim : dims) { - symbol::DimExpr dim_expr; - if (dim == -1) { - symbol::DimExpr res_dim_expr(shape_analysis->GetNextSymName()); - dim_expr = res_dim_expr; - } else { - symbol::DimExpr res_dim_expr(dim); - dim_expr = res_dim_expr; - } - shapes.push_back(dim_expr); - } + std::vector attr_starts = + attributes["starts"].dyn_cast().AsVector(); - // pir::AttributeMap attributes = op->attributes(); + int64_t start = attr_starts[0].dyn_cast().data(); - // auto attr_starts = - // attributes["starts"].dyn_cast().AsVector(); - // auto start = attr_starts[0].dyn_cast().data(); + std::vector out_dims; + if (operand_shape_or_data.data().has_value()) { + out_dims.push_back(operand_shape_or_data.data().value()[start]); + } - // auto attr_ends = - // attributes["ends"].dyn_cast().AsVector(); - // auto end = attr_ends[0].dyn_cast().data(); + symbol::ShapeOrDataDimExprs shape_data{out_dims}; + if (operand_shape_or_data.data().has_value()) { + shape_data = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(shape_data); + } + op->set_attribute( + "symbolic_shape", + pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data)); - symbol::ShapeOrDataDimExprs shape_data{shapes}; - shape_analysis->value_id_to_shapeordata_[res_id] = shape_data; + pir::OpResult res = op->result(0); + shape_analysis->value_to_shape_or_data_[res] = shape_data; return true; } diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h index 75d078986fb55..2d45c8607cdc6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h @@ -107,6 +107,9 @@ bool Reshape_OpInferSymbolicShape( bool FullIntArrayOpInferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis); + } // namespace paddle::dialect namespace cinn::dialect { diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 6e2e105d9c18a..03440156b7e62 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -25,6 +25,7 @@ #include "paddle/pir/core/utils.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" #include "paddle/pir/dialect/control_flow/ir/cf_op.h" +#include "paddle/pir/dialect/shape/ir/shape_attribute.h" namespace paddle { namespace dialect { @@ -33,21 +34,44 @@ struct CombineOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - symbol::ShapeOrDataDimExprs value_shape; + std::vector out_dims; + + // Currently for all operand : type.dims == 1u + for (size_t i = 0; i < op->num_operands(); ++i) { + auto type = + op->operand(i).type().dyn_cast(); + IR_ENFORCE(type, "Currently only support DenseTensorType."); + IR_ENFORCE(type.dims().size() == 0u, + "Currently CombineOp only support 0-d DenseTensorType for " + "InferSymbolicShape. But the dims of the %d-th " + "DenseTensorType is %d.", + i, + type.dims().size()); + } - // for (auto operand_source : op->operands_source()) { - // std::string operand_source_id = pir::GetValueId(&operand_source); - // auto source_shape_vec = - // shape_analysis->value_id_to_shapeordata_[operand_source_id]; - // for (int i = 0; i < source_shape_vec.size(); i++) { - // value_shape.second.emplace_back(source_shape_vec[i]); - // } - // } + auto operand_source_1st_data = + shape_analysis->value_to_shape_or_data_[op->operand_source(0)].data(); + if (operand_source_1st_data.has_value()) { + for (auto operand_source : op->operands_source()) { + auto source_data = + shape_analysis->value_to_shape_or_data_[operand_source] + .data() + .value(); + out_dims.push_back(source_data[0]); + } + } - auto res = op->result(0); - auto res_id = pir::GetValueId(&res); + symbol::ShapeOrDataDimExprs shape_data{out_dims}; + if (operand_source_1st_data.has_value()) { + shape_data = + symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(shape_data); + } - shape_analysis->value_id_to_shapeordata_[res_id] = value_shape; + op->set_attribute("symbolic_shape", + pir::shape::SymbolAttribute::get( + pir::IrContext::Instance(), shape_data)); + auto res = op->result(0); + shape_analysis->value_to_shape_or_data_[res] = shape_data; return true; } diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index ec68a17c9cb13..f48e444030bf4 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -1156,6 +1156,7 @@ kernel : func : slice backward : slice_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : soft_relu args : (Tensor x, float threshold = 20.0f) diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index 4d8408c304df0..384fab2c6369a 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -111,8 +111,10 @@ class InferSymbolicShapePass : public pir::Pass { if (it != infer_sym_shape_map.end()) { it->second(op, shape_analysis_); } else { - LOG(WARNING) << "[" << op.name() - << "] is not supported for infer_symbolic_shape pass."; + if (!op.HasInterface()) { + LOG(WARNING) << "[" << op.name() + << "] is not supported for infer_symbolic_shape pass."; + } } } @@ -442,41 +444,40 @@ void DebugPrintOpInfo( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { for (auto& res : op->results()) { - auto value_id = pir::GetValueId(&res); std::ostringstream print_stream; print_stream << "result(" << res.index() << ") " << "ShapeOrData: "; if (shape_analysis != nullptr) { - auto shape_data = shape_analysis->value_id_to_shapeordata_[value_id]; + auto shape_data = shape_analysis->value_to_shape_or_data_[res]; print_stream << "shape: ["; - for (auto str : shape_data.shape()) { - int64_t* i = std::get_if(&str); - std::string* s = std::get_if(&str); - if (i) { - print_stream << *i << ", "; - } else if (s) { - print_stream << *s << ", "; + for (size_t i = 0; i < shape_data.shape().size(); ++i) { + if (i != shape_data.shape().size() - 1) { + print_stream << symbol::ToString(shape_data.shape()[i]) << ","; + } else { + print_stream << symbol::ToString(shape_data.shape()[i]); } } print_stream << "], data: ["; if (shape_data.data().has_value()) { - for (auto str : shape_data.data().value()) { - int64_t* i = std::get_if(&str); - std::string* s = std::get_if(&str); - if (i) { - print_stream << *i << ", "; - } else if (s) { - print_stream << *s << ", "; + for (size_t i = 0; i < shape_data.data().value().size(); ++i) { + if (i != shape_data.data().value().size() - 1) { + print_stream << symbol::ToString(shape_data.data().value()[i]) + << ","; + } else { + print_stream << symbol::ToString(shape_data.data().value()[i]); } } + } else { + print_stream << "nullopt"; } + print_stream << "]\n"; } - VLOG(0) << print_stream.str(); + VLOG(3) << print_stream.str(); } } @@ -490,7 +491,7 @@ void InferSymExprForAllValues(ModuleOp module_op) { auto infer_symbolic_shape_interface = op.dyn_cast(); if (infer_symbolic_shape_interface) { - VLOG(0) << op.name() << " has InferSymbolicShapeInterface."; + VLOG(3) << op.name() << " has InferSymbolicShapeInterface."; PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape( &shape_analysis)); } diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.cc b/paddle/pir/dialect/shape/ir/shape_dialect.cc index f730256d61beb..083b0d2bd37c0 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.cc +++ b/paddle/pir/dialect/shape/ir/shape_dialect.cc @@ -41,10 +41,29 @@ void ShapeDialect::initialize() { void ShapeDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const { if (attr.isa()) { SymbolAttribute symbol_attr = attr.dyn_cast(); - os << "(shape_or_data)"; + os << "(shape_data)"; + os << "["; for (size_t i = 0; i < symbol_attr.data().shape().size(); ++i) { - os << symbol::ToString(symbol_attr.data().shape()[i]); + if (i != symbol_attr.data().shape().size() - 1) { + os << symbol::ToString(symbol_attr.data().shape()[i]) << ","; + } else { + os << symbol::ToString(symbol_attr.data().shape()[i]); + } } + os << "]_["; + if (symbol_attr.data().data().has_value()) { + for (size_t i = 0; i < symbol_attr.data().data().value().size(); ++i) { + if (i != symbol_attr.data().data().value().size() - 1) { + os << symbol::ToString(symbol_attr.data().data().value()[i]) << ","; + } else { + os << symbol::ToString(symbol_attr.data().data().value()[i]); + } + } + } else { + os << "nullopt"; + } + + os << "]"; } } diff --git a/paddle/pir/dialect/shape/utils/dim_expr.h b/paddle/pir/dialect/shape/utils/dim_expr.h index 277a6febe66ed..adf34f45e1744 100644 --- a/paddle/pir/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/dialect/shape/utils/dim_expr.h @@ -22,9 +22,9 @@ #include #include -#include "paddle/pir/core/dll_decl.h" - #include "glog/logging.h" +#include "paddle/common/enforce.h" +#include "paddle/pir/core/dll_decl.h" namespace symbol { @@ -234,6 +234,22 @@ class ShapeOrData { return ShapeOrData(std::vector{shape}, data); } + static ShapeOrData MakeConsistentShapeOrData( + const ShapeOrData& shape_or_data) { + IR_ENFORCE(shape_or_data.data() == std::nullopt, + "Data of ShapeOrData should be nullopt"); + T shape(std::int64_t(shape_or_data.shape().size())); + return ShapeOrData(std::vector{shape}, shape_or_data.shape()); + } + + int64_t size() const { + if (data_.has_value()) { + return data_.value().size(); + } else { + return shape_.size(); + } + } + // Tensor's real shape const std::vector& shape() const { return shape_; } // Specfic for Tensor generated by shape-relevant ops diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index ac72c0bae88c7..03eee48422c6c 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -87,6 +87,9 @@ class IR_API ShapeConstraintIRAnalysis : public ShapeAnalysis { std::unordered_map value_id_to_shapeordata_; + std::unordered_map + value_to_shape_or_data_; + private: // The operation this analysis runs on. ModuleOp m_; From 094144b8b6dc3c854868aaf8800204f6b71ec213 Mon Sep 17 00:00:00 2001 From: zhangbopd <1299246947@qq.com> Date: Mon, 8 Jan 2024 11:36:03 +0000 Subject: [PATCH 4/4] change header --- paddle/fluid/pybind/pir.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 7561df7851ee9..cbe032dd80b10 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -71,6 +71,7 @@ #include "paddle/pir/core/type.h" #include "paddle/pir/core/value.h" #include "paddle/pir/dialect/control_flow/ir/cf_dialect.h" +#include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" #include "paddle/pir/pass/pass_registry.h" @@ -85,7 +86,6 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" #include "paddle/cinn/hlir/framework/pir_compiler.h" #include "paddle/fluid/pir/transforms/build_cinn_pass.h" -#include "paddle/pir/dialect/shape/ir/shape_dialect.h" #endif namespace py = pybind11;