Skip to content

Commit

Permalink
[Dynamic Shape] Add FullyInsertBroadcastPass and Broadcast Op (#60511)
Browse files Browse the repository at this point in the history
* add ShapeBroadcastOp

* add pass FullyInsertBroadcastPass

* InferSymbolicShape of BroadcastShape Op

* Delete unit test

* Fix return error

* Code format

* Fix error message

* Update paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc

Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>

---------

Co-authored-by: Bo Zhang <105368690+zhangbopd@users.noreply.github.com>
  • Loading branch information
jiahy0825 and zhangbopd committed Jan 4, 2024
1 parent 6b2d74c commit a05f195
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 1 deletion.
9 changes: 9 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,13 @@ if(NOT CINN_ONLY)
cinn_op_dialect
op_dialect_vjp)

cinn_cc_library(
fully_insert_broadcast_pass
SRCS
fully_insert_broadcast_pass.cc
DEPS
pir
cinn_op_dialect
op_dialect_vjp)

endif()
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace cinn {
namespace dialect {
namespace ir {

pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter,
pir::Value x,
pir::Value y) {
pir::Value x_shape = rewriter->Build<paddle::dialect::ShapeOp>(x).out();
pir::Value y_shape = rewriter->Build<paddle::dialect::ShapeOp>(y).out();
return rewriter->Build<paddle::dialect::ShapeBroadcastOp>(x_shape, y_shape)
.out();
}

bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
pir::Value x = op->operand_source(0);
pir::Value y = op->operand_source(1);
pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y);
{
pir::Value broadcasted_x =
rewriter->Build<paddle::dialect::ExpandOp>(x, output_dim_tensor).out();
op->operand(0).set_source(broadcasted_x);
}
{
pir::Value broadcasted_y =
rewriter->Build<paddle::dialect::ExpandOp>(y, output_dim_tensor).out();
op->operand(1).set_source(broadcasted_y);
}
return true;
}

template <typename OPTYPE>
class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
public:
using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;

bool MatchAndRewrite(OPTYPE op,
pir::PatternRewriter& rewriter) const override {
return ProcessOp(op, &rewriter);
}
};

FullyInsertBroadcastPass::FullyInsertBroadcastPass()
: pir::PatternRewritePass("fully_insert_broadcast_pass", 1) {}

pir::RewritePatternSet FullyInsertBroadcastPass::InitializePatterns(
pir::IrContext* context) {
pir::RewritePatternSet ps(context);
// elementwise ops
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::AddOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::SubtractOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MultiplyOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::DivideOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::ElementwisePowOp>>(
context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::RemainderOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::FloorDivideOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MaximumOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::MinimumOp>>(context);

// compare ops
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessThanOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::LessEqualOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::EqualOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::NotEqualOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterThanOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::GreaterEqualOp>>(context);

// bitwise ops
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseOrOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseXorOp>>(context);
ps.Add<FullyInsertBroadcastPattern<paddle::dialect::BitwiseNotOp>>(context);

return ps;
}

bool FullyInsertBroadcastPass::CanApplyOn(pir::Operation* op) const {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}

} // namespace ir
} // namespace dialect
} // namespace cinn
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"

namespace cinn {
namespace dialect {
namespace ir {

class FullyInsertBroadcastPass : public pir::PatternRewritePass {
public:
FullyInsertBroadcastPass();

pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;

bool CanApplyOn(pir::Operation *op) const override;
};

} // namespace ir
} // namespace dialect
} // namespace cinn
156 changes: 155 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
paddle::dialect::SliceArrayOp, paddle::dialect::SliceArrayDenseOp,
paddle::dialect::AssignArray_Op, paddle::dialect::ArrayToTensorOp,
paddle::dialect::SelectInputOp, paddle::dialect::IncrementOp,
paddle::dialect::Increment_Op
paddle::dialect::Increment_Op, paddle::dialect::ShapeBroadcastOp
#else

#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
Expand All @@ -35,6 +35,7 @@ paddle::dialect::AddNOp, paddle::dialect::AddN_Op,
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/phi/api/lib/data_type_set.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
Expand Down Expand Up @@ -2925,6 +2926,158 @@ phi::DataType Increment_Op::GetKernelTypeForVar(
return expected_kernel_dtype;
}

void ShapeBroadcastOp::Build(pir::Builder &builder,
pir::OperationArgument &argument,
pir::Value x_,
pir::Value y_) {
VLOG(4) << "Start build ShapeBroadcastOp";

VLOG(4) << "Builder construction inputs";
std::vector<pir::Value> argument_inputs = {x_, y_};
argument.AddInputs(argument_inputs);

VLOG(4) << "Builder construction attributes";

VLOG(4) << "Builder construction outputs";
paddle::dialect::DenseTensorType x =
x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
paddle::dialect::DenseTensorType y =
y_.type().dyn_cast<paddle::dialect::DenseTensorType>();

VLOG(4) << "Builder construction dense_x";
paddle::dialect::IrTensor ir_tensor_x(
paddle::dialect::TransToPhiDataType(x.dtype()),
x.dims(),
x.data_layout(),
x.lod(),
x.offset());
VLOG(4) << "Builder construction meta_x";
paddle::dialect::IrMetaTensor meta_x(&ir_tensor_x);

VLOG(4) << "Builder construction dense_y";
paddle::dialect::IrTensor ir_tensor_y(
paddle::dialect::TransToPhiDataType(y.dtype()),
y.dims(),
y.data_layout(),
y.lod(),
y.offset());
VLOG(4) << "Builder construction meta_y";
paddle::dialect::IrMetaTensor meta_y(&ir_tensor_y);
paddle::dialect::IrTensor dense_out;
paddle::dialect::IrMetaTensor meta_out(&dense_out);

phi::ElementwiseInferMeta(meta_x, meta_y, &meta_out);

std::vector<pir::Type> argument_outputs;
pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get(
pir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_out.dtype()),
dense_out.dims(),
dense_out.layout(),
dense_out.lod(),
dense_out.offset());
argument_outputs.push_back(out_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
::pir::PassStopGradientsDefaultly(argument);
}

namespace {

void ShapeBroadcastOpInferMeta(const phi::MetaTensor &x,
const phi::MetaTensor &y,
phi::MetaTensor *out) {
PADDLE_ENFORCE_EQ(
x.dims().size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of x.dims() must be equal to 1.", x.dims().size()));
PADDLE_ENFORCE_EQ(
y.dims().size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of y.dims() must be equal to 1.", y.dims().size()));
out->set_dims({std::max<int64_t>(x.dims().at(0), y.dims().at(0))});
// dtype need promote when meet input dtype with more precision
paddle::experimental::DataTypeSet dtype_set{x.dtype()};
dtype_set = dtype_set | paddle::experimental::DataTypeSet(y.dtype());
DataType promote_result = PromoteTypes(dtype_set);
if (promote_result == DataType::UNDEFINED) {
promote_result = x.dtype();
}
out->set_dtype(promote_result);
out->set_layout(x.layout());
out->share_lod(x);
}

} // namespace

void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(ShapeBroadcastOpInferMeta);
fn(infer_meta);
}

phi::DataType ShapeBroadcastOp::GetKernelTypeForVar(
const std::string &var_name,
const phi::DataType &tensor_dtype,
const phi::DataType &expected_kernel_dtype) {
VLOG(4) << "Get KernelType for Var of op: ShapeBroadcastOp";

return expected_kernel_dtype;
}

namespace {

symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs,
const symbol::DimExpr &rhs) {
if (lhs.isa<std::int64_t>() && rhs.isa<std::int64_t>()) {
return std::max(lhs.dyn_cast<std::int64_t>(), rhs.dyn_cast<std::int64_t>());
} else if (lhs.isa<std::int64_t>()) {
return lhs.dyn_cast<std::int64_t>() == 1 ? rhs : lhs;
} else if (rhs.isa<std::int64_t>()) {
return rhs.dyn_cast<std::int64_t>() == 1 ? lhs : rhs;
} else {
return symbol::Broadcast<symbol::DimExpr>{
symbol::List<symbol::DimExpr>{lhs, rhs}};
}
LOG(FATAL) << "Dead code";
}

} // namespace

bool ShapeBroadcastOp::InferSymbolicShape(
pir::ShapeConstraintIRAnalysis *shape_analysis) {
pir::Value x = operand_source(0);
pir::Value y = operand_source(1);
std::string x_id = pir::GetValueId(&x);
std::string y_id = pir::GetValueId(&y);

IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0,
"x_id does not exist.");
IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0,
"y_id does not exist.");
const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id);
const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id);
IR_ENFORCE(x_data_shape.data().has_value(),
"Value x comes from ShapeOp, it must have data");
IR_ENFORCE(y_data_shape.data().has_value(),
"Value y comes from ShapeOp, it must have data");
const auto &x_data = x_data_shape.data().value();
const auto &y_data = y_data_shape.data().value();
IR_ENFORCE(x_data.size() == y_data.size(), "Support same rank temporarily");

std::vector<symbol::DimExpr> output_data;
for (std::size_t i = 0; i < x_data.size(); ++i) {
output_data.emplace_back(GetBroadcastDimExpr(x_data.at(i), y_data.at(i)));
}

pir::OpResult res = result(0);
std::string res_id = pir::GetValueId(&res);
symbol::ShapeOrDataDimExprs output_data_shape =
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data);
shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape;
return true;
}

} // namespace dialect
} // namespace paddle

Expand All @@ -2948,4 +3101,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)
#endif
33 changes: 33 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
Expand Down Expand Up @@ -554,6 +555,37 @@ class Increment_Op
const std::vector<std::vector<bool>> &stop_gradients);
};

class IR_API ShapeBroadcastOp
: public pir::Op<ShapeBroadcastOp,
paddle::dialect::InferSymbolicShapeInterface,
paddle::dialect::InferMetaInterface,
paddle::dialect::GetKernelTypeForVarInterface> {
public:
using Op::Op;
static const char *name() { return "pd_op.shape_broadcast"; }
static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value x_,
pir::Value y_);

void VerifySig() {}

pir::Value x() { return operand_source(0); }
pir::Value y() { return operand_source(1); }
pir::OpResult out() { return result(0); }

static void InferMeta(phi::InferMetaContext *infer_meta);

static phi::DataType GetKernelTypeForVar(
const std::string &var_name,
const phi::DataType &tensor_dtype,
const phi::DataType &expected_kernel_dtype);

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
};

} // namespace dialect
} // namespace paddle

Expand All @@ -577,3 +609,4 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ExpandOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectInputOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IncrementOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::Increment_Op)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShapeBroadcastOp)

0 comments on commit a05f195

Please sign in to comment.