-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Dynamic Shape] Add FullyInsertBroadcastPass and Broadcast Op #60511
Changes from all commits
cb6f0e2
a8a694f
b15a214
f9d81a6
3a19245
716a7f5
2f53f1b
290776a
f9e9232
47dfc4b
0abba77
19cce8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h" | ||
|
||
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" | ||
#include "paddle/cinn/hlir/framework/pir/utils.h" | ||
#include "paddle/common/ddim.h" | ||
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" | ||
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" | ||
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" | ||
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" | ||
#include "paddle/fluid/pir/drr/api/match_context.h" | ||
#include "paddle/pir/core/builtin_dialect.h" | ||
#include "paddle/pir/pass/pass.h" | ||
#include "paddle/pir/pattern_rewrite/pattern_applicator.h" | ||
#include "paddle/pir/pattern_rewrite/pattern_match.h" | ||
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" | ||
|
||
namespace cinn { | ||
namespace dialect { | ||
namespace ir { | ||
|
||
// Materializes the broadcasted output shape of |x| and |y| as a 1-D
// tensor: builds shape(x) and shape(y), then combines them with a
// ShapeBroadcastOp whose result is the broadcasted shape tensor.
pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter,
                              pir::Value x,
                              pir::Value y) {
  const auto& ShapeOf = [&](pir::Value operand) -> pir::Value {
    return rewriter->Build<paddle::dialect::ShapeOp>(operand).out();
  };
  return rewriter
      ->Build<paddle::dialect::ShapeBroadcastOp>(ShapeOf(x), ShapeOf(y))
      .out();
}
|
||
// Rewrites a binary elementwise |op| so both operands are explicitly
// expanded to the broadcasted output shape before being consumed.
// Always returns true to signal that the pattern applied.
bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
  pir::Value lhs = op->operand_source(0);
  pir::Value rhs = op->operand_source(1);
  pir::Value out_dim_tensor = GetOutputDimTensor(rewriter, lhs, rhs);
  const auto& ExpandToOutShape = [&](pir::Value operand) -> pir::Value {
    return rewriter->Build<paddle::dialect::ExpandOp>(operand, out_dim_tensor)
        .out();
  };
  op->operand(0).set_source(ExpandToOutShape(lhs));
  op->operand(1).set_source(ExpandToOutShape(rhs));
  return true;
}
|
||
// Pattern that matches a binary elementwise op of type OPTYPE and rewrites
// it so that both inputs are explicitly broadcast (via ExpandOp) to the
// output shape. The OPTYPE spelling is kept consistent with other cinn
// passes (see PR review discussion).
template <typename OPTYPE>
class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
 public:
  using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;

  bool MatchAndRewrite(OPTYPE op,
                       pir::PatternRewriter& rewriter) const override {
    // Matching is unconditional; ProcessOp performs the rewrite and
    // returns true.
    return ProcessOp(op, &rewriter);
  }
};
|
||
// Registers the pass under the name "fully_insert_broadcast_pass" at
// opt level 1.
FullyInsertBroadcastPass::FullyInsertBroadcastPass()
    : pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {}
|
||
// Builds the pattern set: one FullyInsertBroadcastPattern instantiation
// per supported *binary* op. ProcessOp unconditionally reads operand 0
// and operand 1, so only two-input ops may be registered here.
pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns(
    pir::IrContext* context) {
  pir::RewritePatternSet ps(context);
  // elementwise ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::AddOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(
      context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);

  // compare ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);

  // bitwise ops
  // NOTE: BitwiseNotOp is intentionally NOT registered — bitwise_not is a
  // unary op, and ProcessOp would read the nonexistent operand_source(1).
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseAndOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);

  return ps;
}
|
||
// The pass only runs on a top-level ModuleOp that owns at least one
// region.
bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const {
  if (!op->isa<pir::ModuleOp>()) {
    return false;
  }
  return op->num_regions() > 0;
}
|
||
} // namespace ir | ||
} // namespace dialect | ||
} // namespace cinn |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/pir/pass/pass.h" | ||
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" | ||
|
||
namespace cinn { | ||
namespace dialect { | ||
namespace ir { | ||
|
||
// Pass that inserts explicit broadcast (Expand) ops in front of every
// supported binary elementwise / compare / bitwise op, so that downstream
// dynamic-shape handling always sees operands already expanded to the
// broadcasted output shape.
class FullyInsertBroadcastPass : public pir::PatternRewritePass {
 public:
  FullyInsertBroadcastPass();

  // Registers one rewrite pattern per supported binary op type.
  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;

  // Restricts the pass to non-empty top-level ModuleOps.
  bool CanApplyOn(pir::Operation *op) const override;
};
|
||
} // namespace ir | ||
} // namespace dialect | ||
} // namespace cinn |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op, | |
paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp, | ||
paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp, | ||
paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp, | ||
paddle::dialect::Increment_Op | ||
paddle::dialect::Increment_Op, paddle::dialect::ShapeBroadcastOp | ||
#else | ||
|
||
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" | ||
|
@@ -35,6 +35,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op, | |
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" | ||
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" | ||
#include "paddle/fluid/primitive/rule/vjp/vjp.h" | ||
#include "paddle/phi/api/lib/data_type_set.h" | ||
#include "paddle/phi/api/lib/utils/allocator.h" | ||
#include "paddle/phi/core/dense_tensor.h" | ||
#include "paddle/phi/core/enforce.h" | ||
|
@@ -2925,6 +2926,158 @@ phi::DataType Increment_Op::GetKernelTypeForVar( | |
return expected_kernel_dtype; | ||
} | ||
|
||
// Builds a ShapeBroadcastOp from two shape-tensor inputs x_ and y_, with
// the single output type inferred via phi::ElementwiseInferMeta.
// NOTE(review): this mirrors the paddle-dialect build style and depends on
// paddle-dialect helper types (IrTensor/IrMetaTensor); if the op is later
// moved into pir's shape dialect, those types are unavailable there and
// this logic would need to be rewritten with more basic types.
void ShapeBroadcastOp::Build(pir::Builder &builder,
                             pir::OperationArgument &argument,
                             pir::Value x_,
                             pir::Value y_) {
  VLOG(4) << "Start build ShapeBroadcastOp";

  VLOG(4) << "Builder construction inputs";
  std::vector<pir::Value> argument_inputs = {x_, y_};
  argument.AddInputs(argument_inputs);

  VLOG(4) << "Builder construction attributes";
  // The op carries no attributes.

  VLOG(4) << "Builder construction outputs";
  // NOTE(review): dyn_cast is unchecked — assumes both inputs are
  // DenseTensorType; a mismatched input type would yield a null type
  // here. TODO confirm callers guarantee this.
  paddle::dialect::DenseTensorType x =
      x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
  paddle::dialect::DenseTensorType y =
      y_.type().dyn_cast<paddle::dialect::DenseTensorType>();

  VLOG(4) << "Builder construction dense_x";
  paddle::dialect::IrTensor ir_tensor_x(
      paddle::dialect::TransToPhiDataType(x.dtype()),
      x.dims(),
      x.data_layout(),
      x.lod(),
      x.offset());
  VLOG(4) << "Builder construction meta_x";
  paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x);

  VLOG(4) << "Builder construction dense_y";
  paddle::dialect::IrTensor ir_tensor_y(
      paddle::dialect::TransToPhiDataType(y.dtype()),
      y.dims(),
      y.data_layout(),
      y.lod(),
      y.offset());
  VLOG(4) << "Builder construction meta_y";
  paddle::dialect::IrMetaTensor meta_y(&ir_tensor_y);
  paddle::dialect::IrTensor dense_out;
  paddle::dialect::IrMetaTensor meta_out(&dense_out);

  // Output dims/dtype/layout follow elementwise broadcast rules.
  phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out);

  std::vector<pir::Type> argument_outputs;
  pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
      pir::IrContext::Instance(),
      paddle::dialect::TransToIrDataType(dense_out.dtype()),
      dense_out.dims(),
      dense_out.layout(),
      dense_out.lod(),
      dense_out.offset());
  argument_outputs.push_back(out_dense_tensor_type);
  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
  // Apply the default stop_gradient policy to the new outputs.
  ::pir::PassStopGradientsDefaultly(argument);
}
|
||
namespace { | ||
|
||
// InferMeta for ShapeBroadcastOp: both inputs must be rank-1 (they are
// shape tensors); the output is rank-1 with length max(len(x), len(y)),
// the promoted dtype of the two inputs, and x's layout/lod.
void ShapeBroadcastOpInferMeta(const phi::MetaTensor &x,
                               const phi::MetaTensor &y,
                               phi::MetaTensor *out) {
  PADDLE_ENFORCE_EQ(
      x.dims().size(),
      1,
      phi::errors::PreconditionNotMet(
          "The size %d of x.dims() must be equal to 1.", x.dims().size()));
  PADDLE_ENFORCE_EQ(
      y.dims().size(),
      1,
      phi::errors::PreconditionNotMet(
          "The size %d of y.dims() must be equal to 1.", y.dims().size()));
  // Broadcasting two shape tensors yields a shape tensor of the larger
  // length.
  out->set_dims({std::max<int64_t>(x.dims().at(0), y.dims().at(0))});
  // dtype need promote when meet input dtype with more precision
  paddle::experimental::DataTypeSet dtype_set{x.dtype()};
  dtype_set = dtype_set | paddle::experimental::DataTypeSet(y.dtype());
  DataType promote_result = PromoteTypes(dtype_set);
  if (promote_result == DataType::UNDEFINED) {
    // No promotion rule matched; fall back to x's dtype.
    promote_result = x.dtype();
  }
  out->set_dtype(promote_result);
  out->set_layout(x.layout());
  out->share_lod(x);
}
|
||
} // namespace | ||
|
||
// Adapts ShapeBroadcastOpInferMeta to the generic phi::InferMetaContext
// interface via the PD_INFER_META wrapper.
void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) {
  auto fn = PD_INFER_META(ShapeBroadcastOpInferMeta);
  fn(infer_meta);
}
|
||
// Returns the expected kernel dtype unchanged (no per-variable override).
// NOTE(review): this op has no kernel, so this interface is likely
// unnecessary and is slated for removal in a follow-up PR.
phi::DataType ShapeBroadcastOp::GetKernelTypeForVar(
    const std::string &var_name,
    const phi::DataType &tensor_dtype,
    const phi::DataType &expected_kernel_dtype) {
  VLOG(4) << "Get KernelType for Var of op: ShapeBroadcastOp";

  return expected_kernel_dtype;
}
|
||
namespace { | ||
|
||
// Combines two dimension expressions under broadcast semantics, assuming
// the pair is broadcast-compatible:
//  - both constants: result is the larger (covers the 1-vs-n case; other
//    pairs are assumed equal).
//  - exactly one constant: a constant 1 broadcasts to the symbolic side;
//    any other constant dominates.
//  - both symbolic: defer the decision with a symbolic Broadcast expr.
// Fix: the original ended with an unreachable `LOG(FATAL) << "Dead code";`
// after an exhaustive return chain; restructured so the symbolic fallback
// is the final return and no dead statement remains.
symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs,
                                    const symbol::DimExpr &rhs) {
  if (lhs.isa<std::int64_t>() && rhs.isa<std::int64_t>()) {
    return std::max(lhs.dyn_cast<std::int64_t>(), rhs.dyn_cast<std::int64_t>());
  }
  if (lhs.isa<std::int64_t>()) {
    return lhs.dyn_cast<std::int64_t>() == 1 ? rhs : lhs;
  }
  if (rhs.isa<std::int64_t>()) {
    return rhs.dyn_cast<std::int64_t>() == 1 ? lhs : rhs;
  }
  return symbol::Broadcast<symbol::DimExpr>{
      symbol::List<symbol::DimExpr>{lhs, rhs}};
}
|
||
} // namespace | ||
|
||
// Infers the symbolic shape-or-data of the result from the symbolic *data*
// of the two operands. Both operands are expected to come from ShapeOp, so
// their recorded data (the shape they describe) must exist. Currently only
// equal-rank operands are supported; the result's data is the element-wise
// symbolic broadcast of the operands' data.
bool ShapeBroadcastOp::InferSymbolicShape(
    pir::ShapeConstraintIRAnalysis *shape_analysis) {
  pir::Value x = operand_source(0);
  pir::Value y = operand_source(1);
  std::string x_id = pir::GetValueId(&x);
  std::string y_id = pir::GetValueId(&y);

  // Both operands must already have symbolic info registered in the
  // analysis before this op is visited.
  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0,
             "x_id does not exist.");
  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0,
             "y_id does not exist.");
  const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id);
  const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id);
  IR_ENFORCE(x_data_shape.data().has_value(),
             "Value x comes from ShapeOp, it must have data");
  IR_ENFORCE(y_data_shape.data().has_value(),
             "Value y comes from ShapeOp, it must have data");
  const auto &x_data = x_data_shape.data().value();
  const auto &y_data = y_data_shape.data().value();
  IR_ENFORCE(x_data.size() == y_data.size(), "Support same rank temporarily");

  // Broadcast each dimension pair symbolically.
  std::vector<symbol::DimExpr> output_data;
  for (std::size_t i = 0; i < x_data.size(); ++i) {
    output_data.emplace_back(GetBroadcastDimExpr(x_data.at(i), y_data.at(i)));
  }

  // Record the result's symbolic shape-or-data for downstream analysis.
  pir::OpResult res = result(0);
  std::string res_id = pir::GetValueId(&res);
  symbol::ShapeOrDataDimExprs output_data_shape =
      symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data);
  shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape;
  return true;
}
|
||
} // namespace dialect | ||
} // namespace paddle | ||
|
||
|
@@ -2948,4 +3101,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) | |
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) | ||
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp) | ||
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op) | ||
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp) | ||
#endif |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
#include "paddle/fluid/framework/infershape_utils.h" | ||
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h" | ||
#include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h" | ||
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h" | ||
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" | ||
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" | ||
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h" | ||
|
@@ -554,6 +555,37 @@ class Increment_Op | |
const std::vector<std::vector<bool>> &stop_gradients); | ||
}; | ||
|
||
// Op that computes the broadcasted shape of two 1-D shape tensors
// (pd_op.shape_broadcast). Participates in symbolic shape inference for
// dynamic shapes; it carries no attributes and performs no extra
// signature verification.
class IR_API ShapeBroadcastOp
    : public pir::Op<ShapeBroadcastOp,
                     paddle::dialect::InferSymbolicShapeInterface,
                     paddle::dialect::InferMetaInterface,
                     paddle::dialect::GetKernelTypeForVarInterface> {
 public:
  using Op::Op;
  static const char *name() { return "pd_op.shape_broadcast"; }
  // No attributes are declared for this op.
  static constexpr const char **attributes_name = nullptr;
  static constexpr uint32_t attributes_num = 0;
  static void Build(pir::Builder &builder,            // NOLINT
                    pir::OperationArgument &argument,  // NOLINT
                    pir::Value x_,
                    pir::Value y_);

  // No structural constraints beyond the generic op verification.
  void VerifySig() {}

  pir::Value x() { return operand_source(0); }
  pir::Value y() { return operand_source(1); }
  pir::OpResult out() { return result(0); }

  static void InferMeta(phi::InferMetaContext *infer_meta);

  // NOTE(review): this op has no kernel, so this interface is likely
  // unnecessary and is slated for removal in a follow-up PR.
  static phi::DataType GetKernelTypeForVar(
      const std::string &var_name,
      const phi::DataType &tensor_dtype,
      const phi::DataType &expected_kernel_dtype);

  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
};
|
||
} // namespace dialect | ||
} // namespace paddle | ||
|
||
|
@@ -577,3 +609,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp) | |
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp) | ||
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp) | ||
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op) | ||
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
想请教一下这里的
{
的作用There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
局部作用域,这里不加 {} 也可以的