Skip to content

Commit

Permalink
added op infer
Browse files Browse the repository at this point in the history
  • Loading branch information
GoldenStain committed Sep 16, 2024
1 parent f97db7a commit 0185022
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,38 @@ bool IndexPut_OpInferSymbolicShape(
return IndexPutOpInferSymbolicShape(op, infer_context);
}

bool LogLossOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &input_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const symbol::ShapeOrDataDimExprs &label_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(1));

PADDLE_ENFORCE_EQ(
input_shape_or_data.shape().size(),
2,
common::errors::InvalidArgument(
"ShapeError: input_shape_or_data should have 2 dimensions."));

PADDLE_ENFORCE_EQ(
label_shape_or_data.shape().size(),
2,
common::errors::InvalidArgument(
"ShapeError: label_shape_or_data should have 2 dimensions."));

for (int i = 0; i < 2; i++) {
infer_context->AddEqualCstr(input_shape.shape()[i], label_shape.shape()[i]);
}

std::vector<symbol::DimExpr> output_shape = {symbol::DimExpr{1}};
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(output_shape)});

return true;
}

} // namespace paddle::dialect

namespace cinn::dialect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,81 +18,82 @@

namespace paddle::dialect {

OP_DECLARE_INFER_SYMBOLIC_SHAPE(AccuracyCheck)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Allclose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ApplyPerChannelScale)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Atan2)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BoxClip)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bmm)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BoxClip)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(CholeskySolve)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CtcAlign)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2dTranspose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(ConvTranspose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cross)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Correlation)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cross)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(CtcAlign)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(DepthwiseConv)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dist)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(DepthwiseConv2d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dist)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dot)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dropout)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Embedding)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(EqualAll)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SparseWeightEmbedding)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ExpandAs)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedSoftmaxMask)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GridSample)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gather)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GatherNd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GatherTree)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(HuberLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(GridSample)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Histogram)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Isclose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(HuberLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexAdd)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexAdd_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexPut)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexPut_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSelect)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AccuracyCheck)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSample)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSelect)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSelectStrided)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Isclose)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(KldivLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogLoss)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lstsq)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LuUnpack)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRankTol)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MarginCrossEntropy)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixNms)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MarginCrossEntropy)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatmulWithFlatten)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixNms)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRankTol)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mv)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(PriorBox)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullBoxSparse)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullGpuPsSparse)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(PullSparseV2)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleaveWithTensorIndex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceAs)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleaveWithTensorIndex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Searchsorted)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SegmentPool)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SequenceMask)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShuffleBatch)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Solve)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SparseWeightEmbedding)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stft)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Swiglu)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TakeAlongAxis)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TopPSampling)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TdmChild)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TopPSampling)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(TriangularSolve)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unpool3d)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unpool)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unpool3d)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(WeightDequantize)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(YoloBox)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(YoloBoxHead)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2886,6 +2886,7 @@
kernel :
func : log_loss
backward : log_loss_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : log_softmax
args : (Tensor x, int axis = -1)
Expand Down

0 comments on commit 0185022

Please sign in to comment.