diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 9ef12303a5c0b..45b8d67b19b9a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -2749,11 +2749,7 @@ bool TransposeOpInferSymbolicShape( std::vector dims; const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); - if (x_shape_or_data.data().has_value()) { - dims = x_shape_or_data.data().value(); - } else { - dims = x_shape_or_data.shape(); - } + dims = x_shape_or_data.shape(); return dims; }(); @@ -2795,6 +2791,10 @@ bool Transpose_OpInferSymbolicShape( return TransposeOpInferSymbolicShape(op, infer_context); } +bool TransLayoutOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + return TransposeOpInferSymbolicShape(op, infer_context); +} bool SqueezeOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { PADDLE_ENFORCE_EQ( diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 9600245696389..7446c48fb0559 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -141,6 +141,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Topk) OP_DECLARE_INFER_SYMBOLIC_SHAPE(TopkV1) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Transpose) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Transpose_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(TransLayout) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unbind) OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniformInplace) OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniformInplace_) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index c8fa663264140..7b7e09419dec8 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -4767,7 +4767,7 @@ kernel : func : transpose backward : trans_layout_grad - # interfaces : paddle::dialect::InferSymbolicShapeInterface + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : transpose args : (Tensor x, int[] perm)