Skip to content

Commit

Permalink
【Infer Symbolic Shape No.11】【BUAA】 Add box_coder op (#67864)
Browse files Browse the repository at this point in the history
* Draft version of BoxCoder op

* Finished box_coder op

* Small adjustments

* Finished box_coder op

* Fixed errors

* Fixed errors

* Renamed the function name

* Removed unused variable

* Rewrote part of the original code

* Renamed variables

* Resolved suggested changes
  • Loading branch information
MufanColin committed Sep 4, 2024
1 parent dd261ce commit 8a101d2
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,91 @@ bool BilinearInterpOpInferSymbolicShape(
return BicubicInterpOpInferSymbolicShape(op, infer_context);
}

bool BoxCoderOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &prior_box_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const std::vector<symbol::DimExpr> &prior_box_shape =
prior_box_shape_or_data.shape();

const symbol::ShapeOrDataDimExprs &target_box_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(2));
const std::vector<symbol::DimExpr> &target_box_shape =
target_box_shape_or_data.shape();

const std::string &code_type =
op->attribute<pir::StrAttribute>("code_type").AsString();
int axis = op->attribute<pir::Int32Attribute>("axis").data();
const std::vector<float> &variance =
paddle::dialect::details::GetVectorAttr<float>(op, "variance");

PADDLE_ENFORCE_EQ(prior_box_shape.size(),
2,
phi::errors::InvalidArgument(
"The rank of Input PriorBox in BoxCoder operator "
"must be 2. But received rank = %d",
prior_box_shape.size()));
infer_context->AddEqualCstr(prior_box_shape[1], symbol::DimExpr{4});

if (op->operand_source(1)) {
const symbol::ShapeOrDataDimExprs &prior_box_var_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));
const std::vector<symbol::DimExpr> &prior_box_var_shape =
prior_box_var_shape_or_data.shape();

PADDLE_ENFORCE_EQ(prior_box_var_shape.size(),
2,
phi::errors::InvalidArgument(
"The rank of Input(PriorBoxVar) in BoxCoder operator "
"should be 2. But received rank = %d",
prior_box_var_shape.size()));

for (size_t i = 0; i < prior_box_shape.size(); i++) {
infer_context->AddEqualCstr(prior_box_shape[i], prior_box_var_shape[i]);
}
}

if (code_type == "encode_center_size") {
PADDLE_ENFORCE_EQ(target_box_shape.size(),
2,
phi::errors::InvalidArgument(
"The rank of Input TargetBox in BoxCoder operator "
"must be 2. But received rank is %d",
target_box_shape.size()));

infer_context->AddEqualCstr(target_box_shape[1], symbol::DimExpr{4});
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(
{target_box_shape[0], prior_box_shape[0], symbol::DimExpr(4)})});
} else if (code_type == "decode_center_size") {
PADDLE_ENFORCE_EQ(target_box_shape.size(),
3,
phi::errors::InvalidArgument(
"The rank of Input TargetBox in BoxCoder operator "
"must be 3. But received rank is %d",
target_box_shape.size()));
PADDLE_ENFORCE_EQ(axis == 0 || axis == 1,
true,
phi::errors::InvalidArgument(
"axis in BoxCoder operator must be 0 or 1. "
"But received axis = %d",
axis));
if (axis == 0) {
infer_context->AddEqualCstr(target_box_shape[1], prior_box_shape[0]);
} else if (axis == 1) {
infer_context->AddEqualCstr(target_box_shape[0], prior_box_shape[0]);
}
infer_context->AddEqualCstr(target_box_shape[2], prior_box_shape[1]);
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(target_box_shape)});
}

return true;
}

bool CheckFiniteAndUnscaleOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
// Retrieve the shape information of the input list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BicubicInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bilinear)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BoxCoder)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CheckFiniteAndUnscale)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CheckFiniteAndUnscale_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ChunkEval)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@
kernel :
func : box_coder
optional : prior_box_var
interfaces: paddle::dialect::InferSymbolicShapeInterface

- op : broadcast_tensors
args: (Tensor[] input)
Expand Down

0 comments on commit 8a101d2

Please sign in to comment.