Skip to content

Commit

Permalink
trying to fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
GoldenStain committed Sep 20, 2024
1 parent c4e1092 commit a496cae
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2046,6 +2046,45 @@ bool IndexPut_OpInferSymbolicShape(
return IndexPutOpInferSymbolicShape(op, infer_context);
}

bool LogLossOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::ShapeOrDataDimExprs &input_shape_or_data =
infer_context->GetShapeOrDataForValue(
op->operand_source(0)),
,
&label_shape_or_data =
infer_context->GetShapeOrDataForValue(
op->operand_source(1));

const auto &input_shape = input_shape_or_data.shape();
const auto &label_shape = label_shape_or_data.shape();

PADDLE_ENFORCE_EQ(
input_shape.size() == 2 && label_shape.size() == 2,
true,
common::errors::InvalidArgument(
"The rank of input and label should both be 2, but received: "
"input: %d, label: %d\n",
input_shape.size(),
label_shape.size()))

for (int i = 0; i < 2; i++) {
infer_context->AddEqualCstr(input_shape[i], label_shape[i]);
}

symbol::DimExpr one_dim = symbol::DimExpr{1};

infer_context->AddEqualCstr(input_shape[1], one_dim);
infer_context->AddEqualCstr(label_shape[1], one_dim);

infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(input_shape)});

return true;
}

} // namespace paddle::dialect

namespace cinn::dialect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSample)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(IndexSelectStrided)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(KldivLoss)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogLoss)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lstsq)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(LuUnpack)
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(MatrixRankTol)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2889,6 +2889,7 @@
kernel :
func : log_loss
backward : log_loss_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : log_softmax
args : (Tensor x, int axis = -1)
Expand Down

0 comments on commit a496cae

Please sign in to comment.