From 06c65cee77307b362e906507a62dbab9b8a7b23a Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Fri, 30 Aug 2024 23:45:17 +0800 Subject: [PATCH 1/6] Unfinished yet --- .../multiary_infer_sym.cc | 70 +++++++++++++++++++ .../infer_symbolic_shape/multiary_infer_sym.h | 1 + paddle/phi/ops/yaml/ops.yaml | 1 + 3 files changed, 72 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 1f12f3619a38a..9e8e485f7f9f0 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2041,6 +2041,76 @@ bool MemoryEfficientAttentionOpInferSymbolicShape( return true; } + +bool NllLossOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const symbol::ShapeOrDataDimExprs &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &x_shape = x_shape_or_data.shape(); + const symbol::ShapeOrDataDimExprs &label_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + const std::vector &label_shape = label_shape_or_data.shape(); + PADDLE_ENFORCE_EQ(x_shape.size() == 2 || x_shape.size() == 4, + true, + phi::errors::InvalidArgument( + "The tensor rank of Input(X) must be 2 or 4.")); + PADDLE_ENFORCE_EQ( + x_shape[0], + label_shape[0], + common::errors::InvalidArgument( + "ShapeError: Expected input batch_size to match label batch_size," + "But received: the Input(x) batch_size is [%s], the Input(label) " + " batch_size is [%s].", + x_shape[0], + label_shape[0])); + if (op->operand_source(2)) { + const symbol::ShapeOrDataDimExprs &w_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(2)); + const std::vector &w_shape = w_shape_or_data.shape(); + PADDLE_ENFORCE_EQ(w_shape.size(), + 1, + common::errors::InvalidArgument( + "Input(Weight) should be a 1D tensor.")); + PADDLE_ENFORCE_EQ( + x_shape[1], + w_shape[0], + common::errors::InvalidArgument( + "Expected input tensor Weight's size should equal " + "to the first dimension of the input tensor X. But received " + "Weight's " + "size is %d, the first dimension of input X is %d", + w_shape[0], + x_shape[1])); + } + + const std::string &reduction = + op->attribute("reduction").AsString(); + + if (x_shape.size() == 2) { + if (reduction == "none") { + } else { + } + } else if (x_shape.size() == 4) { + PADDLE_ENFORCE_EQ(label_shape.size(), + 3, + common::errors::InvalidArgument( + "Expected Input(Label) dimensions=3, received %d.", + label_shape.size())); + auto input0 = x_shape[0]; + auto input2 = x_shape[2]; + auto input3 = x_shape[3]; + auto label0 = label_shape[0]; + auto label1 = label_shape[1]; + auto label2 = label_shape[2]; + PADDLE_ENFORCE_EQ( + input0 == label0 && input2 == label1 && input3 == label2, + true, + phi::errors::InvalidArgument("Input(X) tensor shape should " + "match to Input(Label) tensor " + "shape.")); + } +} + bool RoiPoolOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index f2a1becc0a4db..6336a0327fa44 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -89,6 +89,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale) OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nce) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(NllLoss) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(PsroiPool) OP_DECLARE_INFER_SYMBOLIC_SHAPE(QuantizeLinear) OP_DECLARE_INFER_SYMBOLIC_SHAPE(QuantizeLinear_) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 24d04ae87b851..a3dc990275c07 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3429,6 +3429,7 @@ data_type : input optional : weight backward : nll_loss_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : nms args : (Tensor x, float threshold = 1.0f) From f446b501502a29babdf9106bf6a604e61b997e4c Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Sat, 31 Aug 2024 10:09:55 +0800 Subject: [PATCH 2/6] Finished nll loss op --- .../multiary_infer_sym.cc | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 9e8e485f7f9f0..bd78fddb36125 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2057,7 +2057,7 @@ bool NllLossOpInferSymbolicShape( PADDLE_ENFORCE_EQ( x_shape[0], label_shape[0], - common::errors::InvalidArgument( + phi::errors::InvalidArgument( "ShapeError: Expected input batch_size to match label batch_size," "But received: the Input(x) batch_size is [%s], the Input(label) " " batch_size is [%s].", @@ -2067,14 +2067,14 @@ bool NllLossOpInferSymbolicShape( const symbol::ShapeOrDataDimExprs &w_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(2)); const std::vector &w_shape = w_shape_or_data.shape(); - PADDLE_ENFORCE_EQ(w_shape.size(), - 1, - common::errors::InvalidArgument( - "Input(Weight) should be a 1D tensor.")); + PADDLE_ENFORCE_EQ( + w_shape.size(), + 1, + phi::errors::InvalidArgument("Input(Weight) should be a 1D tensor.")); PADDLE_ENFORCE_EQ( x_shape[1], w_shape[0], - common::errors::InvalidArgument( + phi::errors::InvalidArgument( "Expected input tensor Weight's size should equal " "to the first dimension of the input tensor X. But received " "Weight's " @@ -2086,10 +2086,17 @@ bool NllLossOpInferSymbolicShape( const std::string &reduction = op->attribute("reduction").AsString(); + std::vector out_shape; if (x_shape.size() == 2) { if (reduction == "none") { + out_shape = {x_shape[0]}; } else { + out_shape = std::vector{}; } + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); } else if (x_shape.size() == 4) { PADDLE_ENFORCE_EQ(label_shape.size(), 3, @@ -2108,7 +2115,21 @@ bool NllLossOpInferSymbolicShape( phi::errors::InvalidArgument("Input(X) tensor shape should " "match to Input(Label) tensor " "shape.")); + + if (reduction == "none") { + out->set_dims({x_dims[0], x_dims[2], x_dims[3]}); + } else { + out_shape = std::vector{}; + } + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); } + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(std::vector{})}); } bool RoiPoolOpInferSymbolicShape( From 52290b7e0120de1cd5660ec8febbcc160ebb576c Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Sat, 31 Aug 2024 11:03:59 +0800 Subject: [PATCH 3/6] Finished nll loss op --- .../multiary_infer_sym.cc | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index bd78fddb36125..ecfbfff40822e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2071,16 +2071,8 @@ bool NllLossOpInferSymbolicShape( w_shape.size(), 1, phi::errors::InvalidArgument("Input(Weight) should be a 1D tensor.")); - PADDLE_ENFORCE_EQ( - x_shape[1], - w_shape[0], - phi::errors::InvalidArgument( - "Expected input tensor Weight's size should equal " - "to the first dimension of the input tensor X. But received " - "Weight's " - "size is %d, the first dimension of input X is %d", - w_shape[0], - x_shape[1])); + + infer_context->AddEqualCstr(x_shape[1], w_shape[0]); } const std::string &reduction = @@ -2100,15 +2092,15 @@ bool NllLossOpInferSymbolicShape( } else if (x_shape.size() == 4) { PADDLE_ENFORCE_EQ(label_shape.size(), 3, - common::errors::InvalidArgument( + phi::errors::InvalidArgument( "Expected Input(Label) dimensions=3, received %d.", label_shape.size())); - auto input0 = x_shape[0]; - auto input2 = x_shape[2]; - auto input3 = x_shape[3]; - auto label0 = label_shape[0]; - auto label1 = label_shape[1]; - auto label2 = label_shape[2]; + symbol::DimExpr input0 = x_shape[0]; + symbol::DimExpr input2 = x_shape[2]; + symbol::DimExpr input3 = x_shape[3]; + symbol::DimExpr label0 = label_shape[0]; + symbol::DimExpr label1 = label_shape[1]; + symbol::DimExpr label2 = label_shape[2]; PADDLE_ENFORCE_EQ( input0 == label0 && input2 == label1 && input3 == label2, true, @@ -2117,7 +2109,7 @@ bool NllLossOpInferSymbolicShape( "shape.")); if (reduction == "none") { - out->set_dims({x_dims[0], x_dims[2], x_dims[3]}); + out_shape = {x_dims[0], x_dims[2], x_dims[3]}; } else { out_shape = std::vector{}; } From 8a25ea06d6938898974bf3f5feddf97a405b2385 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Sat, 31 Aug 2024 11:24:57 +0800 Subject: [PATCH 4/6] Fixed errors --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index ecfbfff40822e..2c37be82ed031 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2109,7 +2109,7 @@ bool NllLossOpInferSymbolicShape( "shape.")); if (reduction == "none") { - out_shape = {x_dims[0], x_dims[2], x_dims[3]}; + out_shape = {x_shape[0], x_shape[2], x_shape[3]}; } else { out_shape = std::vector{}; } From 41409376270977815e437e346386393f667427d2 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Sat, 31 Aug 2024 11:40:57 +0800 Subject: [PATCH 5/6] Added return true --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 2c37be82ed031..c4f2ab82841f6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2122,6 +2122,7 @@ bool NllLossOpInferSymbolicShape( op->result(1), symbol::ShapeOrDataDimExprs{ symbol::TensorShapeOrDataDimExprs(std::vector{})}); + return true; } bool RoiPoolOpInferSymbolicShape( From 9eff13a9af99f54c625022b83db5f3c6a54a45c2 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Mon, 2 Sep 2024 15:23:30 +0800 Subject: [PATCH 6/6] Resolved suggested changes --- .../multiary_infer_sym.cc | 27 +++++-------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index c4f2ab82841f6..199c0c0d2df1c 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -2054,15 +2054,8 @@ bool NllLossOpInferSymbolicShape( true, phi::errors::InvalidArgument( "The tensor rank of Input(X) must be 2 or 4.")); - PADDLE_ENFORCE_EQ( - x_shape[0], - label_shape[0], - phi::errors::InvalidArgument( - "ShapeError: Expected input batch_size to match label batch_size," - "But received: the Input(x) batch_size is [%s], the Input(label) " - " batch_size is [%s].", - x_shape[0], - label_shape[0])); + infer_context->AddEqualCstr(x_shape[0], label_shape[0]); + if (op->operand_source(2)) { const symbol::ShapeOrDataDimExprs &w_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(2)); @@ -2095,18 +2088,10 @@ bool NllLossOpInferSymbolicShape( phi::errors::InvalidArgument( "Expected Input(Label) dimensions=3, received %d.", label_shape.size())); - symbol::DimExpr input0 = x_shape[0]; - symbol::DimExpr input2 = x_shape[2]; - symbol::DimExpr input3 = x_shape[3]; - symbol::DimExpr label0 = label_shape[0]; - symbol::DimExpr label1 = label_shape[1]; - symbol::DimExpr label2 = label_shape[2]; - PADDLE_ENFORCE_EQ( - input0 == label0 && input2 == label1 && input3 == label2, - true, - phi::errors::InvalidArgument("Input(X) tensor shape should " - "match to Input(Label) tensor " - "shape.")); + + infer_context->AddEqualCstr(x_shape[0], label_shape[0]); + infer_context->AddEqualCstr(x_shape[2], label_shape[1]); + infer_context->AddEqualCstr(x_shape[3], label_shape[2]); if (reduction == "none") { out_shape = {x_shape[0], x_shape[2], x_shape[3]};