From 620525c2ea985e9ef74ec04277f589c4e586403d Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 11:17:30 +0800 Subject: [PATCH 01/19] batch_function --- .../multiary_infer_sym.cc | 55 +++++++++++++++++-- .../infer_symbolic_shape/multiary_infer_sym.h | 2 +- paddle/phi/ops/yaml/ops.yaml | 1 + 3 files changed, 51 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index c7cff62df9e2f..88be7a5a4b05b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -228,12 +228,55 @@ bool AucOpInferSymbolicShape(pir::Operation *op, return true; } -// bool BatchFcOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext -// *infer_context) { -// // pass -// return true; -// } +bool BatchFCOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &input_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &w_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const auto &bias_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + + const std::vector &input_dims = input_shape_or_data.shape(); + const std::vector &w_dims = w_shape_or_data.shape(); + const std::vector &bias_dims = bias_shape_or_data.shape(); + + PADDLE_ENFORCE_EQ( + input_dims.size(), + 3, + common::errors::InvalidArgument("Input of BatchFCOp should have 3D.")); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 3, + common::errors::InvalidArgument("W of BatchFCOp should have 3D.")); + PADDLE_ENFORCE_EQ( + input_dims[0], + w_dims[0], + common::errors::InvalidArgument( + "Input.dim[0] and W.dim[0] of BatchFCOp should be same.")); + PADDLE_ENFORCE_EQ( + input_dims[2], + w_dims[1], + common::errors::InvalidArgument( + "Input.dim[2] and W.dim[1] of BatchFCOp should be same.")); + PADDLE_ENFORCE_EQ(bias_dims[0], + input_dims[0], + common::errors::InvalidArgument( + "Bias.dim[0] should be same as input.dim[0].")); + PADDLE_ENFORCE_EQ(bias_dims[1], + w_dims[2], + common::errors::InvalidArgument( + "Bias.dim[1] should be same as input.dim[2].")); + + std::vector out_dims = { + input_dims[0], input_dims[1], w_dims[2]}; + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); + + return true; +} bool BatchNormOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index dccefa7e149d4..bedb1e01b0b9e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -25,7 +25,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignPos) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BicubicInterp) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index dd30e85fc84b0..23708ebb84fc4 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -464,6 +464,7 @@ func : batch_fc data_type: input backward: batch_fc_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : bce_loss args : (Tensor input, Tensor label) From 16be43e3e351421ae1af13c99c7722f677b1fff6 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 11:27:29 +0800 Subject: [PATCH 02/19] bincount --- .../infer_symbolic_shape/binary_infer_sym.cc | 51 ++++++++++++++++--- .../infer_symbolic_shape/binary_infer_sym.h | 2 +- paddle/phi/ops/yaml/ops.yaml | 1 + 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 43c35dc905ada..6f0ac34b47205 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -173,12 +173,51 @@ bool Binomial_OpInferSymbolicShape( return BinomialOpInferSymbolicShape(op, infer_context); } -// bool BincountOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext -// *infer_context) { -// // pass -// return true; -// } +bool BincountOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &x_dims = x_shape_or_data.shape(); + + PADDLE_ENFORCE_EQ(x_dims.size(), + 1, + common::errors::InvalidArgument( + "The 'shape' of Input(X) must be 1-D tensor. But the " + "dimension of Input(X) is [%d]", + x_dims.size())); + + if (op->operand_source(1)) { + const auto &weights_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const std::vector &weights_dims = + weights_shape_or_data.shape(); + + PADDLE_ENFORCE_EQ(weights_dims.size(), + 1, + common::errors::InvalidArgument( + "The 'shape' of Input(Weights) must be 1-D tensor. " + "But the dimension of Input(Weights) is [%d]", + weights_dims.size())); + + PADDLE_ENFORCE_EQ( + weights_dims[0], + x_dims[0], + common::errors::InvalidArgument( + "The 'shape' of Input(Weights) must be equal to the 'shape' of " + "Input(X). But received: the 'shape' of Input(Weights) is [%s], " + "the 'shape' of Input(X) is [%s]", + weights_dims, + x_dims)); + } + + // Set the output shape, which is of unknown size (-1) + std::vector out_dims = {symbol::DimExpr(-1)}; + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); + + return true; +} // bool BmmOpInferSymbolicShape(pir::Operation *op, // pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index e7ee88b249029..642858b185762 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -24,7 +24,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial_) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bmm) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(CholeskySolve) OP_DECLARE_INFER_SYMBOLIC_SHAPE(CtcAlign) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 23708ebb84fc4..93c8929403be4 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -546,6 +546,7 @@ kernel: func: bincount optional: weights + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : binomial args : (Tensor count, Tensor prob) From 47c5d452e930d5cf6c4d11158f55208c09cace3e Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 14:04:42 +0800 Subject: [PATCH 03/19] update batchfc --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 88be7a5a4b05b..67a50f753a09b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -228,7 +228,7 @@ bool AucOpInferSymbolicShape(pir::Operation *op, return true; } -bool BatchFCOpInferSymbolicShape( +bool BatchFcOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &input_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); From c4142fe0f7b8c268ba7bc14ea1694be0d65626b5 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 14:05:05 +0800 Subject: [PATCH 04/19] update batchfc --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 67a50f753a09b..2b4474b98663b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -244,21 +244,21 @@ bool BatchFcOpInferSymbolicShape( PADDLE_ENFORCE_EQ( input_dims.size(), 3, - common::errors::InvalidArgument("Input of BatchFCOp should have 3D.")); + common::errors::InvalidArgument("Input of BatchFcOp should have 3D.")); PADDLE_ENFORCE_EQ( w_dims.size(), 3, - common::errors::InvalidArgument("W of BatchFCOp should have 3D.")); + common::errors::InvalidArgument("W of BatchFcOp should have 3D.")); PADDLE_ENFORCE_EQ( input_dims[0], w_dims[0], common::errors::InvalidArgument( - "Input.dim[0] and W.dim[0] of BatchFCOp should be same.")); + "Input.dim[0] and W.dim[0] of BatchFcOp should be same.")); PADDLE_ENFORCE_EQ( input_dims[2], w_dims[1], common::errors::InvalidArgument( - "Input.dim[2] and W.dim[1] of BatchFCOp should be same.")); + "Input.dim[2] and W.dim[1] of BatchFcOp should be same.")); PADDLE_ENFORCE_EQ(bias_dims[0], input_dims[0], common::errors::InvalidArgument( From 6379805cbbf402737fc2b07b236aa2154dd03765 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 15:12:34 +0800 Subject: [PATCH 05/19] update EQ --- .../infer_symbolic_shape/binary_infer_sym.cc | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 6f0ac34b47205..b46f65e65b3c0 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -198,16 +198,7 @@ bool BincountOpInferSymbolicShape( "The 'shape' of Input(Weights) must be 1-D tensor. " "But the dimension of Input(Weights) is [%d]", weights_dims.size())); - - PADDLE_ENFORCE_EQ( - weights_dims[0], - x_dims[0], - common::errors::InvalidArgument( - "The 'shape' of Input(Weights) must be equal to the 'shape' of " - "Input(X). But received: the 'shape' of Input(Weights) is [%s], " - "the 'shape' of Input(X) is [%s]", - weights_dims, - x_dims)); + infer_context->AddEqualCstr(weights_dims[0], x_dims[0]); } // Set the output shape, which is of unknown size (-1) From e90891fed155c81151bb722ee70e74e82d299d8e Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 15:28:27 +0800 Subject: [PATCH 06/19] update {-1} --- .../operator/interface/infer_symbolic_shape/binary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index b46f65e65b3c0..a5eff4871c8d4 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -202,7 +202,7 @@ bool BincountOpInferSymbolicShape( } // Set the output shape, which is of unknown size (-1) - std::vector out_dims = {symbol::DimExpr(-1)}; + std::vector out_dims = {symbol::DimExpr({-1})}; infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); From 8b3ad1daab974ab49d9731f120301f4a0c3e010d Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 16:25:07 +0800 Subject: [PATCH 07/19] update binary with output_size --- .../infer_symbolic_shape/binary_infer_sym.cc | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index a5eff4871c8d4..bbe93d4335c88 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -179,6 +179,20 @@ bool BincountOpInferSymbolicShape( infer_context->GetShapeOrDataForValue(op->operand_source(0)); const std::vector &x_dims = x_shape_or_data.shape(); + const std::vector &input_data = x_shape_or_data.data(); + const std::vector &input_numel = input_data->numerl(); + + int64_t output_size = static_cast(*std::max_element( + input_data[0], input_data[0] + input_numel[0])) + + 1L; + + const auto &minlength_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + const std::vector &minlength_data = + minlength_shape_or_data.data(); + + output_size = std::max(output_size, static_cast(minlength_data[0])); + PADDLE_ENFORCE_EQ(x_dims.size(), 1, common::errors::InvalidArgument( @@ -202,7 +216,7 @@ bool BincountOpInferSymbolicShape( } // Set the output shape, which is of unknown size (-1) - std::vector out_dims = {symbol::DimExpr({-1})}; + std::vector out_dims = {symbol::DimExpr(output_size)}; infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); From 949b95c3748d899791ea679bd5c09c5955862a2b Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 16:44:27 +0800 Subject: [PATCH 08/19] undo change --- .../infer_symbolic_shape/binary_infer_sym.cc | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index bbe93d4335c88..a5eff4871c8d4 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -179,20 +179,6 @@ bool BincountOpInferSymbolicShape( infer_context->GetShapeOrDataForValue(op->operand_source(0)); const std::vector &x_dims = x_shape_or_data.shape(); - const std::vector &input_data = x_shape_or_data.data(); - const std::vector &input_numel = input_data->numerl(); - - int64_t output_size = static_cast(*std::max_element( - input_data[0], input_data[0] + input_numel[0])) + - 1L; - - const auto &minlength_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(2)); - const std::vector &minlength_data = - minlength_shape_or_data.data(); - - output_size = std::max(output_size, static_cast(minlength_data[0])); - PADDLE_ENFORCE_EQ(x_dims.size(), 1, common::errors::InvalidArgument( @@ -216,7 +202,7 @@ bool BincountOpInferSymbolicShape( } // Set the output shape, which is of unknown size (-1) - std::vector out_dims = {symbol::DimExpr(output_size)}; + std::vector out_dims = {symbol::DimExpr({-1})}; infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); From 3f1c49febaba1e282e1590fe43cd81fb8389a974 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 17:32:58 +0800 Subject: [PATCH 09/19] add Bincount --- .../infer_symbolic_shape/binary_infer_sym.cc | 5 +- .../multiary_infer_sym.cc | 55 ++----------------- .../infer_symbolic_shape/multiary_infer_sym.h | 2 +- paddle/phi/ops/yaml/ops.yaml | 2 +- 4 files changed, 12 insertions(+), 52 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index a5eff4871c8d4..dd42b6b040945 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -201,8 +201,11 @@ bool BincountOpInferSymbolicShape( infer_context->AddEqualCstr(weights_dims[0], x_dims[0]); } + symbol::DimExpr out_unknown = + infer_context->GetNextSymName(); // unknown until runtime + // Set the output shape, which is of unknown size (-1) - std::vector out_dims = {symbol::DimExpr({-1})}; + std::vector out_dims = {symbol::DimExpr({out_unknown})}; infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 2b4474b98663b..c7cff62df9e2f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -228,55 +228,12 @@ bool AucOpInferSymbolicShape(pir::Operation *op, return true; } -bool BatchFcOpInferSymbolicShape( - pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - const auto &input_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(0)); - const auto &w_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(1)); - const auto &bias_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(2)); - - const std::vector &input_dims = input_shape_or_data.shape(); - const std::vector &w_dims = w_shape_or_data.shape(); - const std::vector &bias_dims = bias_shape_or_data.shape(); - - PADDLE_ENFORCE_EQ( - input_dims.size(), - 3, - common::errors::InvalidArgument("Input of BatchFcOp should have 3D.")); - PADDLE_ENFORCE_EQ( - w_dims.size(), - 3, - common::errors::InvalidArgument("W of BatchFcOp should have 3D.")); - PADDLE_ENFORCE_EQ( - input_dims[0], - w_dims[0], - common::errors::InvalidArgument( - "Input.dim[0] and W.dim[0] of BatchFcOp should be same.")); - PADDLE_ENFORCE_EQ( - input_dims[2], - w_dims[1], - common::errors::InvalidArgument( - "Input.dim[2] and W.dim[1] of BatchFcOp should be same.")); - PADDLE_ENFORCE_EQ(bias_dims[0], - input_dims[0], - common::errors::InvalidArgument( - "Bias.dim[0] should be same as input.dim[0].")); - PADDLE_ENFORCE_EQ(bias_dims[1], - w_dims[2], - common::errors::InvalidArgument( - "Bias.dim[1] should be same as input.dim[2].")); - - std::vector out_dims = { - input_dims[0], input_dims[1], w_dims[2]}; - - infer_context->SetShapeOrDataForValue( - op->result(0), - symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); - - return true; -} +// bool BatchFcOpInferSymbolicShape(pir::Operation *op, +// pir::InferSymbolicShapeContext +// *infer_context) { +// // pass +// return true; +// } bool BatchNormOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index bedb1e01b0b9e..5f5ea0b001a7f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -26,7 +26,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignPos) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm) +// OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BicubicInterp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bilinear) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 93c8929403be4..a9471e90b1269 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -464,7 +464,7 @@ func : batch_fc data_type: input backward: batch_fc_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface + # interfaces : paddle::dialect::InferSymbolicShapeInterface - op : bce_loss args : (Tensor input, Tensor label) From eef3cd6c77bab38bd3d901ce6f9efba5f0a6595f Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 17:41:22 +0800 Subject: [PATCH 10/19] delete batch_fc --- paddle/phi/ops/yaml/ops.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index a9471e90b1269..90a5e15b54325 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -464,7 +464,6 @@ func : batch_fc data_type: input backward: batch_fc_grad - # interfaces : paddle::dialect::InferSymbolicShapeInterface - op : bce_loss args : (Tensor input, Tensor label) From 5d447ea45a984cf3fb7432f378aed91a98567193 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Tue, 13 Aug 2024 13:09:02 +0800 Subject: [PATCH 11/19] update batchnormop --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 894678f2c4d1e..12cd81d053549 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -235,7 +235,7 @@ bool AucOpInferSymbolicShape(pir::Operation *op, // return true; // } -bool BatchNormOpInferSymbolicShape( +bool BatchNorm_OpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); @@ -334,9 +334,9 @@ bool BatchNormOpInferSymbolicShape( return true; } -bool BatchNorm_OpInferSymbolicShape( +bool BatchNormOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - return BatchNormOpInferSymbolicShape(op, infer_context); + return BatchNorm_OpInferSymbolicShape(op, infer_context); } bool BicubicInterpOpInferSymbolicShape( From 7c24bfe4e1f59feedacfd35d9f8c9abb0cd3f67e Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Tue, 13 Aug 2024 13:11:25 +0800 Subject: [PATCH 12/19] update batchnormop --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 12cd81d053549..894678f2c4d1e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -235,7 +235,7 @@ bool AucOpInferSymbolicShape(pir::Operation *op, // return true; // } -bool BatchNorm_OpInferSymbolicShape( +bool BatchNormOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); @@ -334,9 +334,9 @@ bool BatchNorm_OpInferSymbolicShape( return true; } -bool BatchNormOpInferSymbolicShape( +bool BatchNorm_OpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - return BatchNorm_OpInferSymbolicShape(op, infer_context); + return BatchNormOpInferSymbolicShape(op, infer_context); } bool BicubicInterpOpInferSymbolicShape( From 877670c5c4412e648acbd97fcd394153e53fca92 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Tue, 13 Aug 2024 14:00:53 +0800 Subject: [PATCH 13/19] changed bn --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 894678f2c4d1e..12cd81d053549 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -235,7 +235,7 @@ bool AucOpInferSymbolicShape(pir::Operation *op, // return true; // } -bool BatchNormOpInferSymbolicShape( +bool BatchNorm_OpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); @@ -334,9 +334,9 @@ bool BatchNormOpInferSymbolicShape( return true; } -bool BatchNorm_OpInferSymbolicShape( +bool BatchNormOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - return BatchNormOpInferSymbolicShape(op, infer_context); + return BatchNorm_OpInferSymbolicShape(op, infer_context); } bool BicubicInterpOpInferSymbolicShape( From 4b2cc16aa1241b0fa09c8896239873b9dfb92269 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Tue, 13 Aug 2024 14:16:41 +0800 Subject: [PATCH 14/19] changed bn --- .../interface/infer_symbolic_shape/multiary_infer_sym.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index f4e0f1d49d761..095590eca991d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -25,8 +25,8 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignPos) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm) +// OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BicubicInterp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bilinear) From 9685c7d1b89ca2bc87cf29dde6aae7ff17ba46ec Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Wed, 14 Aug 2024 17:32:58 +0800 Subject: [PATCH 15/19] unduo _ --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 12cd81d053549..894678f2c4d1e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -235,7 +235,7 @@ bool AucOpInferSymbolicShape(pir::Operation *op, // return true; // } -bool BatchNorm_OpInferSymbolicShape( +bool BatchNormOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); @@ -334,9 +334,9 @@ bool BatchNorm_OpInferSymbolicShape( return true; } -bool BatchNormOpInferSymbolicShape( +bool BatchNorm_OpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - return BatchNorm_OpInferSymbolicShape(op, infer_context); + return BatchNormOpInferSymbolicShape(op, infer_context); } bool BicubicInterpOpInferSymbolicShape( From a8f127777f6609dc89db15a264f98e34be6117b3 Mon Sep 17 00:00:00 2001 From: Whsjrczr <123729598+Whsjrczr@users.noreply.github.com> Date: Thu, 15 Aug 2024 18:05:25 +0800 Subject: [PATCH 16/19] Update DimExpr --- .../operator/interface/infer_symbolic_shape/binary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index f6d2310589649..15ae96c922b52 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -232,7 +232,7 @@ bool BincountOpInferSymbolicShape( infer_context->GetNextSymName(); // unknown until runtime // Set the output shape, which is of unknown size (-1) - std::vector out_dims = {symbol::DimExpr({out_unknown})}; + std::vector out_dims = {{out_unknown}}; infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); From b09ad2c50f631e0f53389e6481b85034bd538e34 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Fri, 16 Aug 2024 16:29:33 +0800 Subject: [PATCH 17/19] {{out_dims}} -> out_dims --- .../interface/infer_symbolic_shape/binary_infer_sym.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index f6d2310589649..b9f2b8977e548 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -233,9 +233,7 @@ bool BincountOpInferSymbolicShape( // Set the output shape, which is of unknown size (-1) std::vector out_dims = {symbol::DimExpr({out_unknown})}; - infer_context->SetShapeOrDataForValue( - op->result(0), - symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); + infer_context->SetShapeOrDataForValue(op->result(0), out_dims); return true; } From c3327cfd5fdc1025fc638d059771eea78f1643de Mon Sep 17 00:00:00 2001 From: Whsjrczr <123729598+Whsjrczr@users.noreply.github.com> Date: Tue, 20 Aug 2024 21:27:37 +0800 Subject: [PATCH 18/19] Update out_dims --- .../operator/interface/infer_symbolic_shape/binary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index f1779f3aaa6cd..8f5e164c15e83 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -232,7 +232,7 @@ bool BincountOpInferSymbolicShape( infer_context->GetNextSymName(); // unknown until runtime // Set the output shape, which is of unknown size (-1) - std::vector out_dims = {out_unknown}; + const std::vector out_dims = {out_unknown}; infer_context->SetShapeOrDataForValue(op->result(0), out_dims); return true; From c02d0c7fa9c645148c51ba75558da61ef2330769 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Wed, 21 Aug 2024 16:00:40 +0000 Subject: [PATCH 19/19] update annotation and symbol test --- .../interface/infer_symbolic_shape/binary_infer_sym.cc | 9 ++++----- test/legacy_test/test_bincount_op.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 8f5e164c15e83..6ca0058eba5a4 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -228,12 +228,11 @@ bool BincountOpInferSymbolicShape( infer_context->AddEqualCstr(weights_dims[0], x_dims[0]); } - symbol::DimExpr out_unknown = - infer_context->GetNextSymName(); // unknown until runtime - - // Set the output shape, which is of unknown size (-1) + symbol::DimExpr out_unknown = infer_context->GetNextSymName(); const std::vector out_dims = {out_unknown}; - infer_context->SetShapeOrDataForValue(op->result(0), out_dims); + symbol::ShapeOrDataDimExprs output_dims{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + infer_context->SetShapeOrDataForValue(op->result(0), output_dims); return true; } diff --git a/test/legacy_test/test_bincount_op.py b/test/legacy_test/test_bincount_op.py index 2d031f16133f8..8f330e686cf67 100644 --- a/test/legacy_test/test_bincount_op.py +++ b/test/legacy_test/test_bincount_op.py @@ -156,7 +156,7 @@ def init_test_case(self): self.Out = np.bincount(self.np_input, minlength=self.minlength) def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(check_pir=True, check_symbol_infer=False) class TestCase1(TestBincountOp):