-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CINN] Add infer_symbol_shape for some ops #68166
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
const symbol::ShapeOrDataDimExprs &input_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(0)); | ||
const symbol::ShapeOrDataDimExprs &label_shape = | ||
infer_context->GetShapeOrDataForValue(op->operand_source(1)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
推荐命名input_shape_or_data, 符号推导里有shape和data区的设计,一般在取shape时才使用 _shape 后缀
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc
Outdated
Show resolved
Hide resolved
PADDLE_ENFORCE_EQ( | ||
input_shape.shape()[0], | ||
label_shape.shape()[0], | ||
common::errors::InvalidArgument( | ||
"ShapeError: The batch_size of input and label must be the same. " | ||
"But received input batch_size = %d, label batch_size = %d", | ||
input_shape.shape()[0], | ||
label_shape.shape()[0])); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dim 之间的约束应该使用 Addequalcstr()
paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h
Outdated
Show resolved
Hide resolved
|
||
// 二分类任务,label 通常是一维的(batch_size, 1),确保 input 的最后一维与 | ||
// label 的最后一维相同 | ||
if (input_shape.shape().size() == 2 && label_shape.shape().size() == 2) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PADDLE_ENFORCE_EQ 强制约束 input 和 label 的shape区数据size() == 2。然后写个for循环把两个dim都加上equalcstr
84a7d74
to
030b47a
Compare
030b47a
to
0185022
Compare
CI时身份信息验证失败,无法通过ci |
PR-CE-Framework单测的paddle-infer项目,日志显示 |
infer_context->AddEqualCstr(input_shape_or_data.shape()[i], | ||
label_shape_or_data.shape()[i]); | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
麻烦补充一个equalcstr(input_shape[1], dimexpr(1)); 的约束,对于动态维度,能添加约束信息就尽量添加
label_shape_or_data.shape()[i]); | ||
} | ||
|
||
std::vector<symbol::DimExpr> output_shape = {symbol::DimExpr{1}}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是不是应该是 {input_shape[0], symbol::DimExpr{1}}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
|
1cc9eee
to
2ec05ac
Compare
97b1076
to
a496cae
Compare
通过本地单测 |
PR-CI-Model-benchmark日志显示"fail; gpu_mem has increased, please contact RD for repair." 即便把paddle最新仓库原封不动提交跑PR也是这样; |
目前已经通过,之前导致挂掉的问题负责同学已经修复 |
symbol::DimExpr one_dim = symbol::DimExpr{1}; | ||
|
||
infer_context->AddEqualCstr(input_shape[1], one_dim); | ||
infer_context->AddEqualCstr(label_shape[1], one_dim); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里其实添加一个就可以,上面已经添加了input_shape[1]和label_shape[1]的约束。可以放到下个pr修改
麻烦PR Title 中补充上log_loss op 信息 |
infer_context->GetShapeOrDataForValue( | ||
op->operand_source(1)); | ||
|
||
const auto &input_shape = input_shape_or_data.shape(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里可以直接写成
const auto &input_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里可以直接写成
const auto &input_shape = infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
@GoldenStain 下一个PR改一下哈
PR Category
CINN
PR Types
Improvements
Description
添加log_loss算子符号推导实现,按照字母序对原先的所有函数声明重新排序