Skip to content

Commit

Permalink
[PIR][DynamicShape] make shape pass default and fix some bugs (#60548)
Browse files Browse the repository at this point in the history
As titled: make the shape-optimization pass the default and fix several related bugs.
  • Loading branch information
lanxianghit committed Jan 6, 2024
1 parent ee3d2fc commit 7c7c5b1
Show file tree
Hide file tree
Showing 18 changed files with 152 additions and 831 deletions.
13 changes: 7 additions & 6 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <variant>
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/dll_decl.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/core/operation.h"
Expand All @@ -24,7 +25,7 @@
namespace cinn {
namespace dialect {

class GroupOp : public pir::Op<GroupOp> {
class IR_API GroupOp : public pir::Op<GroupOp> {
public:
using Op::Op;
static const char *name() { return "cinn_op.group"; }
Expand Down Expand Up @@ -82,7 +83,7 @@ class IR_API SplitOp : public pir::Op<SplitOp> {
void VerifySig() const {}
};

class GenerateShapeOp : public pir::Op<GenerateShapeOp> {
class IR_API GenerateShapeOp : public pir::Op<GenerateShapeOp> {
public:
using Op::Op;
static const char *name() { return "cinn_op.generate_shape"; }
Expand Down Expand Up @@ -121,7 +122,7 @@ class GenerateShapeOp : public pir::Op<GenerateShapeOp> {
} // namespace dialect
} // namespace cinn

IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp)
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp);
80 changes: 69 additions & 11 deletions paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op,
shapes.push_back(dim_expr);
}

symbol::ShapeOrDataDimExprs shape_data{shapes};
symbol::ShapeOrDataDimExprs shape_data{
shapes,
shape_analysis->value_id_to_shapeordata_[operand_source_id].shape()};
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
return true;
}
Expand Down Expand Up @@ -146,9 +148,9 @@ bool ReshapeOpInferSymbolicShape(
pir::OpResult res = op->result(0);
std::string res_id = pir::GetValueId(&res);

symbol::ShapeOrDataDimExprs shape_data;
symbol::ShapeOrDataDimExprs shape_data{
*(shape_analysis->value_id_to_shapeordata_[operand_source_1_id].data())};

shape_data = shape_analysis->value_id_to_shapeordata_[operand_source_1_id];
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
return true;
}
Expand All @@ -158,6 +160,54 @@ bool Reshape_OpInferSymbolicShape(
return ReshapeOpInferSymbolicShape(op, shape_analysis);
}

bool SliceOpInferSymbolicShape(pir::Operation *op,
                               pir::ShapeConstraintIRAnalysis *shape_analysis) {
  // Infers the symbolic shape/data of a pd_op.slice result and records it in
  // shape_analysis->value_id_to_shapeordata_ under the result's value id.
  // Operand 0 is the sliced tensor; operands 1 and 2 carry the `starts` and
  // `ends` values, whose symbolic *data* must already have been recorded by
  // the passes that inferred those operands.
  pir::Value operand_source = op->operand_source(0);
  std::string operand_source_id = pir::GetValueId(&operand_source);
  pir::OpResult res = op->result(0);
  std::string res_id = pir::GetValueId(&res);

  // Build the result's symbolic shape from its static dims; every dynamic
  // dim (-1) gets a fresh symbol.
  std::vector<int64_t> dims =
      common::vectorize(res.type().dyn_cast<pir::DenseTensorType>().dims());

  std::vector<symbol::DimExpr> shapes;
  shapes.reserve(dims.size());
  for (int64_t dim : dims) {
    if (dim == -1) {
      shapes.emplace_back(shape_analysis->GetNextSymName());
    } else {
      shapes.emplace_back(dim);
    }
  }

  // NOTE(review): only element 0 of starts/ends is consulted, so this handles
  // a single sliced axis; the .data() optionals are assumed to be populated
  // by the producers of operands 1/2 — TODO confirm both assumptions.
  pir::Value starts_value = op->operand_source(1);
  std::string starts_id = pir::GetValueId(&starts_value);
  auto starts_array =
      (shape_analysis->value_id_to_shapeordata_[starts_id]).data();
  // Keep the full int64 range: `Get<int64_t>()` yields int64_t, and the loop
  // below must not narrow it (the original used a plain `int` index).
  int64_t start = starts_array->at(0).Get<int64_t>();

  pir::Value ends_value = op->operand_source(2);
  std::string ends_id = pir::GetValueId(&ends_value);
  auto ends_array =
      (shape_analysis->value_id_to_shapeordata_[ends_id]).data();
  int64_t end = ends_array->at(0).Get<int64_t>();

  // Slice the source's symbolic data over [start, end).
  std::vector<symbol::DimExpr> data;
  auto source_data =
      (shape_analysis->value_id_to_shapeordata_[operand_source_id]).data();
  if (end > start) {
    data.reserve(static_cast<size_t>(end - start));
  }
  for (int64_t i = start; i < end; i++) {
    data.emplace_back(source_data->at(i));
  }

  symbol::ShapeOrDataDimExprs shape_data{shapes, data};
  shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
  return true;
}

} // namespace paddle::dialect
namespace cinn::dialect {

Expand All @@ -184,17 +234,25 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
shapes.push_back(dim_expr);
}

// pir::AttributeMap attributes = op->attributes();
pir::AttributeMap attributes = op->attributes();

auto attr_starts =
attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector();
auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data();

// auto attr_starts =
// attributes["starts"].dyn_cast<pir::ArrayAttribute>().AsVector();
// auto start = attr_starts[0].dyn_cast<pir::Int64Attribute>().data();
auto attr_ends =
attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector();
auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data();

// auto attr_ends =
// attributes["ends"].dyn_cast<pir::ArrayAttribute>().AsVector();
// auto end = attr_ends[0].dyn_cast<pir::Int64Attribute>().data();
std::vector<symbol::DimExpr> data;
auto source_data =
(shape_analysis->value_id_to_shapeordata_[operand_source_id]).data();

for (int i = start; i < end; i++) {
data.emplace_back(source_data->at(i));
}

symbol::ShapeOrDataDimExprs shape_data{shapes};
symbol::ShapeOrDataDimExprs shape_data{shapes, data};
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ bool ReshapeOpInferSymbolicShape(
bool Reshape_OpInferSymbolicShape(
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis);

bool SliceOpInferSymbolicShape(pir::Operation *op,
pir::ShapeConstraintIRAnalysis *shape_analysis);

} // namespace paddle::dialect

namespace cinn::dialect {
Expand Down
26 changes: 15 additions & 11 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,25 @@ struct CombineOpInferSymbolicShapeInterfaceModel
: public InferSymbolicShapeInterface::Concept {
static inline bool InferSymbolicShape(
pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) {
symbol::ShapeOrDataDimExprs value_shape;

// for (auto operand_source : op->operands_source()) {
// std::string operand_source_id = pir::GetValueId(&operand_source);
// auto source_shape_vec =
// shape_analysis->value_id_to_shapeordata_[operand_source_id];
// for (int i = 0; i < source_shape_vec.size(); i++) {
// value_shape.second.emplace_back(source_shape_vec[i]);
// }
// }
std::vector<symbol::DimExpr> shapes;
std::vector<symbol::DimExpr> data;

for (auto operand_source : op->operands_source()) {
std::string operand_source_id = pir::GetValueId(&operand_source);
auto source_data_p =
shape_analysis->value_id_to_shapeordata_[operand_source_id].data();
auto source_shape_vec =
source_data_p.value_or(std::vector<symbol::DimExpr>{});
for (size_t i = 0; i < source_shape_vec.size(); i++) {
data.emplace_back(source_shape_vec.at(i));
}
}

auto res = op->result(0);
auto res_id = pir::GetValueId(&res);

shape_analysis->value_id_to_shapeordata_[res_id] = value_shape;
symbol::ShapeOrDataDimExprs shape_data{shapes, data};
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
return true;
}

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,7 @@
kernel :
func : slice
backward : slice_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : soft_relu
args : (Tensor x, float threshold = 20.0f)
Expand Down
7 changes: 2 additions & 5 deletions paddle/fluid/pir/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
file(GLOB_RECURSE transforms_srcs "*.cc")
if(NOT WITH_CINN)
list(
REMOVE_ITEM
transforms_srcs
${CMAKE_CURRENT_SOURCE_DIR}/build_cinn_pass.cc
REMOVE_ITEM transforms_srcs ${CMAKE_CURRENT_SOURCE_DIR}/build_cinn_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_extract_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_detector.cc
${CMAKE_CURRENT_SOURCE_DIR}/shape_optimization_pass.cc)
${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_detector.cc)
endif()

set(transforms_deps drr op_dialect op_dialect_vjp standalone_executor pir
Expand Down
Loading

0 comments on commit 7c7c5b1

Please sign in to comment.