
[Dynamic Shape] Add FullyInsertBroadcastPass and Broadcast Op #60511

Merged (12 commits) on Jan 4, 2024
9 changes: 9 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
@@ -29,4 +29,13 @@ if(NOT CINN_ONLY)
    cinn_op_dialect
    op_dialect_vjp)

  cinn_cc_library(
    fully_insert_broadcast_pass
    SRCS
    fully_insert_broadcast_pass.cc
    DEPS
    pir
    cinn_op_dialect
    op_dialect_vjp)

endif()
paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc
@@ -0,0 +1,112 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace cinn {
namespace dialect {
namespace ir {

pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter,
                              pir::Value x,
                              pir::Value y) {
  pir::Value x_shape = rewriter->Build<paddle::dialect::ShapeOp>(x).out();
  pir::Value y_shape = rewriter->Build<paddle::dialect::ShapeOp>(y).out();
  return rewriter->Build<paddle::dialect::ShapeBroadcastOp>(x_shape, y_shape)
      .out();
}

bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
  pir::Value x = op->operand_source(0);
  pir::Value y = op->operand_source(1);
  pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y);
  {
[Comment from Contributor]
Could you explain what the bare `{` here is for?

[Reply from Author]
It opens a local scope; the code would also work without the `{}`.
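For reference, a minimal standalone illustration of the bare-block idiom (not part of the diff; plain C++):

#include <iostream>

int main() {
  {  // bare block: names declared here go out of scope at the closing brace
    int tmp = 1;
    std::cout << tmp << "\n";
  }
  int tmp = 2;  // legal: the earlier `tmp` is no longer in scope
  std::cout << tmp << "\n";
  return 0;
}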

    pir::Value broadcasted_x =
        rewriter->Build<paddle::dialect::ExpandOp>(x, output_dim_tensor).out();
    op->operand(0).set_source(broadcasted_x);
  }
  {
    pir::Value broadcasted_y =
        rewriter->Build<paddle::dialect::ExpandOp>(y, output_dim_tensor).out();
    op->operand(1).set_source(broadcasted_y);
  }
  return true;
}
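Taken together, GetOutputDimTensor and ProcessOp make broadcasting explicit for every matched binary op. A rough before/after sketch in pseudo-IR (op spellings approximated, not actual printer output; pd_op.shape_broadcast is the op added by this PR, the shape/expand spellings are assumptions):

// Before:
//   %out = pd_op.add(%x, %y)
//
// After:
//   %x_shape = pd_op.shape(%x)
//   %y_shape = pd_op.shape(%y)
//   %bcast   = pd_op.shape_broadcast(%x_shape, %y_shape)
//   %x_b     = pd_op.expand(%x, %bcast)
//   %y_b     = pd_op.expand(%y, %bcast)
//   %out     = pd_op.add(%x_b, %y_b)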

template <typename OPTYPE>
[Comment from Contributor]
Wouldn't `OpTy` or `OpType` be a better name?

[Reply from Author]
This keeps the naming consistent with the other CINN passes.
class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
 public:
  using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;

  bool MatchAndRewrite(OPTYPE op,
                       pir::PatternRewriter& rewriter) const override {
    return ProcessOp(op, &rewriter);
  }
};

FullyInsertBroadcastPass::FullyInsertBroadcastPass()
    : pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {}

pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns(
    pir::IrContext* context) {
  pir::RewritePatternSet ps(context);
  // elementwise ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::AddOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(
      context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);

  // compare ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);

  // bitwise ops
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
  ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);

  return ps;
}

bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const {
  return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}

} // namespace ir
} // namespace dialect
} // namespace cinn
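As a usage sketch only: the pass would typically be scheduled through a pass manager. The exact driver API is not part of this diff, so the names below (pir::PassManager, AddPass, Run) are assumptions based on common pir usage:

// Hypothetical wiring; not part of this PR.
pir::IrContext* ctx = pir::IrContext::Instance();
pir::PassManager pm(ctx);
pm.AddPass(std::make_unique<cinn::dialect::ir::FullyInsertBroadcastPass>());
pm.Run(&program);  // `program` is an existing pir::Program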
paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h
@@ -0,0 +1,35 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"

namespace cinn {
namespace dialect {
namespace ir {

class FullyInsertBroadcastPass : public pir::PatternRewritePass {
 public:
  FullyInsertBroadcastPass();

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;

  bool CanApplyOn(pir::Operation *op) const override;
};

} // namespace ir
} // namespace dialect
} // namespace cinn
156 changes: 155 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op.cc
@@ -23,7 +23,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp,
paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp,
paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp,
-paddle::dialect::Increment_Op
+paddle::dialect::Increment_Op, paddle::dialect::ShapeBroadcastOp
#else

#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
@@ -35,6 +35,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/phi/api/lib/data_type_set.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
@@ -2925,6 +2926,158 @@ phi::DataType Increment_Op::GetKernelTypeForVar(
  return expected_kernel_dtype;
}

void ShapeBroadcastOp::Build(pir::Builder &builder,
                             pir::OperationArgument &argument,
                             pir::Value x_,
                             pir::Value y_) {
  VLOG(4) << "Start build ShapeBroadcastOp";

  VLOG(4) << "Builder construction inputs";
  std::vector<pir::Value> argument_inputs = {x_, y_};
  argument.AddInputs(argument_inputs);

  VLOG(4) << "Builder construction attributes";

  VLOG(4) << "Builder construction outputs";
  paddle::dialect::DenseTensorType x =
      x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
  paddle::dialect::DenseTensorType y =
      y_.type().dyn_cast<paddle::dialect::DenseTensorType>();

  VLOG(4) << "Builder construction dense_x";
  paddle::dialect::IrTensor ir_tensor_x(
      paddle::dialect::TransToPhiDataType(x.dtype()),
      x.dims(),
      x.data_layout(),
      x.lod(),
      x.offset());
  VLOG(4) << "Builder construction meta_x";
  paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x);

  VLOG(4) << "Builder construction dense_y";
  paddle::dialect::IrTensor ir_tensor_y(
[Comment from Contributor]
This uses many data types from the paddle dialect. If this op later needs to move into pir's shape dialect, pir cannot include paddle-dialect headers as dependencies. Does that make the migration impossible, or would Build have to be rewritten on top of more basic data types first?

[Reply from Author]
This actually mirrors how ops are written under the paddle dialect; if the op is moved into the shape dialect, this logic should no longer be needed.
      paddle::dialect::TransToPhiDataType(y.dtype()),
      y.dims(),
      y.data_layout(),
      y.lod(),
      y.offset());
  VLOG(4) << "Builder construction meta_y";
  paddle::dialect::IrMetaTensor meta_y(&ir_tensor_y);
  paddle::dialect::IrTensor dense_out;
  paddle::dialect::IrMetaTensor meta_out(&dense_out);

  phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out);

  std::vector<pir::Type> argument_outputs;
  pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
      pir::IrContext::Instance(),
      paddle::dialect::TransToIrDataType(dense_out.dtype()),
      dense_out.dims(),
      dense_out.layout(),
      dense_out.lod(),
      dense_out.offset());
  argument_outputs.push_back(out_dense_tensor_type);
  argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
  ::pir::PassStopGradientsDefaultly(argument);
}

namespace {

void ShapeBroadcastOpInferMeta(const phi::MetaTensor &x,
                               const phi::MetaTensor &y,
                               phi::MetaTensor *out) {
  PADDLE_ENFORCE_EQ(
      x.dims().size(),
      1,
      phi::errors::PreconditionNotMet(
          "The size %d of x.dims() must be equal to 1.", x.dims().size()));
  PADDLE_ENFORCE_EQ(
      y.dims().size(),
      1,
      phi::errors::PreconditionNotMet(
          "The size %d of y.dims() must be equal to 1.", y.dims().size()));
  out->set_dims({std::max<int64_t>(x.dims().at(0), y.dims().at(0))});
  // The dtype needs promotion when an input carries a higher-precision dtype.
  paddle::experimental::DataTypeSet dtype_set{x.dtype()};
  dtype_set = dtype_set | paddle::experimental::DataTypeSet(y.dtype());
  DataType promote_result = PromoteTypes(dtype_set);
  if (promote_result == DataType::UNDEFINED) {
    promote_result = x.dtype();
  }
  out->set_dtype(promote_result);
  out->set_layout(x.layout());
  out->share_lod(x);
}

}  // namespace
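Reading the rule above on a concrete case (values illustrative; the promotion result assumes the usual lattice in PromoteTypes):

// x: shape tensor with dtype int32 and dims [3]  (describes a 3-D shape)
// y: shape tensor with dtype int64 and dims [5]  (describes a 5-D shape)
//
// out dims   = [max(3, 5)] = [5]
// out dtype  = PromoteTypes({int32, int64})      -> int64 (assumed)
// out layout = x.layout(); lod is shared from x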

void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) {
  auto fn = PD_INFER_META(ShapeBroadcastOpInferMeta);
  fn(infer_meta);
}

phi::DataType ShapeBroadcastOp::GetKernelTypeForVar(
    const std::string &var_name,
    const phi::DataType &tensor_dtype,
    const phi::DataType &expected_kernel_dtype) {
  VLOG(4) << "Get KernelType for Var of op: ShapeBroadcastOp";

  return expected_kernel_dtype;
}

namespace {

symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs,
                                    const symbol::DimExpr &rhs) {
  if (lhs.isa<std::int64_t>() && rhs.isa<std::int64_t>()) {
    return std::max(lhs.dyn_cast<std::int64_t>(), rhs.dyn_cast<std::int64_t>());
  } else if (lhs.isa<std::int64_t>()) {
    return lhs.dyn_cast<std::int64_t>() == 1 ? rhs : lhs;
  } else if (rhs.isa<std::int64_t>()) {
    return rhs.dyn_cast<std::int64_t>() == 1 ? lhs : rhs;
  } else {
    return symbol::Broadcast<symbol::DimExpr>{
        symbol::List<symbol::DimExpr>{lhs, rhs}};
  }
  LOG(FATAL) << "Dead code";
}

}  // namespace
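For reference, the expected result of each branch above (DimExpr spellings approximated; S0/S1 denote symbolic dims):

// GetBroadcastDimExpr(2,  3)  -> 3                    (both static: max)
// GetBroadcastDimExpr(1,  S0) -> S0                   (static 1 defers to the other side)
// GetBroadcastDimExpr(4,  S0) -> 4                    (a non-1 static dim wins)
// GetBroadcastDimExpr(S0, S1) -> Broadcast([S0, S1])  (both symbolic: deferred)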

bool ShapeBroadcastOp::InferSymbolicShape(
    pir::ShapeConstraintIRAnalysis *shape_analysis) {
  pir::Value x = operand_source(0);
  pir::Value y = operand_source(1);
  std::string x_id = pir::GetValueId(&x);
  std::string y_id = pir::GetValueId(&y);

  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0,
             "x_id does not exist.");
  IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0,
             "y_id does not exist.");
  const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id);
  const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id);
  IR_ENFORCE(x_data_shape.data().has_value(),
             "Value x comes from ShapeOp, it must have data");
  IR_ENFORCE(y_data_shape.data().has_value(),
             "Value y comes from ShapeOp, it must have data");
  const auto &x_data = x_data_shape.data().value();
  const auto &y_data = y_data_shape.data().value();
  IR_ENFORCE(x_data.size() == y_data.size(), "Support same rank temporarily");

  std::vector<symbol::DimExpr> output_data;
  for (std::size_t i = 0; i < x_data.size(); ++i) {
    output_data.emplace_back(GetBroadcastDimExpr(x_data.at(i), y_data.at(i)));
  }

  pir::OpResult res = result(0);
  std::string res_id = pir::GetValueId(&res);
  symbol::ShapeOrDataDimExprs output_data_shape =
      symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data);
  shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape;
  return true;
}
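End to end, the loop applies the per-dimension rule from the branch table above. A small worked example (assumed inputs):

// x_data = [S0, 8], y_data = [1, 8]
//   element 0: GetBroadcastDimExpr(S0, 1) -> S0   (rhs is static 1)
//   element 1: GetBroadcastDimExpr(8, 8)  -> 8    (both static: max)
// => the result value's data is [S0, 8]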

} // namespace dialect
} // namespace paddle

@@ -2948,4 +3101,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)
#endif
33 changes: 33 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
@@ -554,6 +555,37 @@ class Increment_Op
      const std::vector<std::vector<bool>> &stop_gradients);
};

class IR_API ShapeBroadcastOp
    : public pir::Op<ShapeBroadcastOp,
                     paddle::dialect::InferSymbolicShapeInterface,
                     paddle::dialect::InferMetaInterface,
                     paddle::dialect::GetKernelTypeForVarInterface> {
 public:
  using Op::Op;
  static const char *name() { return "pd_op.shape_broadcast"; }
  static constexpr const char **attributes_name = nullptr;
  static constexpr uint32_t attributes_num = 0;
  static void Build(pir::Builder &builder,             // NOLINT
                    pir::OperationArgument &argument,  // NOLINT
                    pir::Value x_,
                    pir::Value y_);

  void VerifySig() {}

  pir::Value x() { return operand_source(0); }
  pir::Value y() { return operand_source(1); }
  pir::OpResult out() { return result(0); }

  static void InferMeta(phi::InferMetaContext *infer_meta);

  static phi::DataType GetKernelTypeForVar(
      const std::string &var_name,
      const phi::DataType &tensor_dtype,
      const phi::DataType &expected_kernel_dtype);
[Comment from Contributor, on lines +581 to +584]
An op that has no kernel should not need this interface.

[Reply from Author]
Agreed; a follow-up PR will remove it.

  bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
};

} // namespace dialect
} // namespace paddle

@@ -577,3 +609,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)