Skip to content

Commit

Permalink
[Dynamic Shape] Add helper function MakeGenerateShapeOpAttribute (#60512
Browse files Browse the repository at this point in the history
)

* add helper function MakeGenerateShapeOpAttribute

* fix complier complaint

* Code format
  • Loading branch information
jiahy0825 committed Jan 3, 2024
1 parent 698bb42 commit 54b95ae
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 198 deletions.
235 changes: 235 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include <unordered_set>
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_attribute.h"

Expand Down Expand Up @@ -422,4 +423,238 @@ MakeGetterDimExpr4SymbolName(
};
}

namespace {

bool IsAtomicImpl(int64_t) { return true; }

bool IsAtomicImpl(const std::string&) { return true; }

bool IsAtomicImpl(const symbol::Negative<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Reciprocal<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Add<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Mul<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Max<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Min<symbol::DimExpr>&) { return false; }

bool IsAtomicImpl(const symbol::Broadcast<symbol::DimExpr>&) { return false; }

bool IsAtomic(const symbol::DimExpr& dim_expr) {
return std::visit([](const auto& impl) { return IsAtomicImpl(impl); },
dim_expr.variant());
}

bool InputDimExprsAllSupported(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<pir::Value>& input_tensors) {
const auto& AllSupported =
[](const std::vector<symbol::DimExpr>& dim_exprs) -> bool {
for (const auto& dim_expr : dim_exprs) {
if (!IsAtomic(dim_expr)) return false;
}
return true;
};
for (const auto& input_tensor : input_tensors) {
const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor);
if (!AllSupported(dim_exprs.shape())) return false;
if (dim_exprs.data().has_value()) {
if (!AllSupported(dim_exprs.data().value())) return false;
}
}
return true;
}

void ConvertDimExprToAttributes(pir::IrContext* ir_context,
const std::vector<symbol::DimExpr>& dim_exprs,
std::vector<pir::Attribute>* attrs) {
attrs->clear();
attrs->reserve(dim_exprs.size());
for (const auto& dim_expr : dim_exprs) {
attrs->emplace_back(ConvertDimExprToAttribute(ir_context, dim_expr));
}
}

void CollectSymbolNames(const symbol::DimExpr& dim_expr,
std::set<std::string>* ret);

void CollectSymbolNamesImpl(const int64_t& dim_expr,
std::set<std::string>* ret) {
// do nothing.
}

void CollectSymbolNamesImpl(const std::string& dim_expr,
std::set<std::string>* ret) {
ret->insert(dim_expr);
}

template <typename T>
void CollectSymbolNamesImplForUnary(const T& dim_expr,
std::set<std::string>* ret) {
const auto& [operand] = *dim_expr;
CollectSymbolNames(operand, ret);
}

void CollectSymbolNamesImpl(const symbol::Negative<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForUnary(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForUnary(dim_expr, ret);
}

template <typename T>
void CollectSymbolNamesImplForVariadic(const T& dim_expr,
std::set<std::string>* ret) {
const auto& operands = *(dim_expr.operands);
for (const auto& operand : operands) {
CollectSymbolNames(operand, ret);
}
}

void CollectSymbolNamesImpl(const symbol::Add<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Mul<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Max<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Min<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNamesImpl(const symbol::Broadcast<symbol::DimExpr>& dim_expr,
std::set<std::string>* ret) {
CollectSymbolNamesImplForVariadic(dim_expr, ret);
}

void CollectSymbolNames(const symbol::DimExpr& dim_expr,
std::set<std::string>* ret) {
return std::visit(
[&](const auto& impl) { return CollectSymbolNamesImpl(impl, ret); },
dim_expr.variant());
}

void CollectSymbolNames(const std::vector<symbol::DimExpr>& dim_exprs,
std::set<std::string>* ret) {
for (const auto& dim_expr : dim_exprs) {
CollectSymbolNames(dim_expr, ret);
}
}

template <typename SymbolBindingsT>
void AppendSymbolBindings(const std::vector<symbol::DimExpr>& dim_exprs,
const std::set<std::string>& symbol_names,
int in_tensor_idx,
GenerateShapeOp::SymbolBindings* symbol_bindings) {
for (int in_tensor_dim_idx = 0; in_tensor_dim_idx < dim_exprs.size();
++in_tensor_dim_idx) {
const auto& dim_expr = dim_exprs.at(in_tensor_dim_idx);
CHECK(IsAtomic(dim_expr));
if (!dim_expr.isa<std::string>()) continue;
const auto& sym_name = dim_expr.dyn_cast<std::string>();
if (symbol_names.find(sym_name) == symbol_names.end()) continue;
symbol_bindings->emplace_back(SymbolBindingsT{
/*.symbol_name=*/sym_name,
/*.input_tensor_idx=*/in_tensor_idx,
/*.input_tensor_dim_idx=*/in_tensor_dim_idx,
});
}
}

void GenerateSymbolBindings(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<pir::Value>& input_tensors,
const std::set<std::string>& symbol_names,
GenerateShapeOp::SymbolBindings* symbol_bindings) {
for (int i = 0; i < input_tensors.size(); ++i) {
const auto& input_tensor = input_tensors.at(i);
const auto& dim_exprs = ShapeOrDataDimExprs4Value(input_tensor);
AppendSymbolBindings<GenerateShapeOp::ShapeSymbolBinding>(
dim_exprs.shape(), symbol_names, i, symbol_bindings);
if (dim_exprs.data().has_value()) {
AppendSymbolBindings<GenerateShapeOp::DataSymbolBinding>(
dim_exprs.shape(), symbol_names, i, symbol_bindings);
}
}
}

std::vector<pir::Value> GetMinimalInputs(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<pir::Value>& input_tensors) {
std::unordered_set<symbol::DimExpr> handdled_dim_exprs;
std::unordered_set<pir::Value> first_occurred_input_tensors;
auto TryCollectFirstOcurredInput_tensor =
[&](pir::Value input_tensor,
const std::vector<symbol::DimExpr>& dim_exprs) {
for (const auto& dim_expr : dim_exprs) {
if (dim_expr.isa<int64_t>()) continue;
if (!handdled_dim_exprs.insert(dim_expr).second) {
first_occurred_input_tensors.insert(input_tensor);
}
}
};
for (pir::Value input_tensor : input_tensors) {
const auto& shape_or_data_dim_exprs =
ShapeOrDataDimExprs4Value(input_tensor);
if (shape_or_data_dim_exprs.data().has_value()) {
TryCollectFirstOcurredInput_tensor(
input_tensor, shape_or_data_dim_exprs.data().value());
}
TryCollectFirstOcurredInput_tensor(input_tensor,
shape_or_data_dim_exprs.shape());
}
std::vector<pir::Value> ret{};
ret.reserve(input_tensors.size());
for (pir::Value input_tensor : input_tensors) {
if (first_occurred_input_tensors.count(input_tensor) > 0) {
ret.emplace_back(input_tensor);
}
}
return ret;
}

} // namespace

bool MakeGenerateShapeOpAttribute(
pir::IrContext* ir_context,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<symbol::DimExpr>& out_dim_exprs,
const std::vector<pir::Value>& origin_inputs,
std::vector<pir::Value>* minial_inputs,
std::vector<pir::Attribute>* output_dim_expr_attrs,
GenerateShapeOp::SymbolBindings* symbol_bindings) {
*minial_inputs = GetMinimalInputs(ShapeOrDataDimExprs4Value, origin_inputs);
if (!InputDimExprsAllSupported(ShapeOrDataDimExprs4Value, *minial_inputs)) {
VLOG(4) << "input dim_exprs are not as simple as symbols, please make sure "
"they are handled by other passes";
return false;
}
// generate output_dim_expr_attrs
ConvertDimExprToAttributes(
ir_context, out_dim_exprs, /*out*/ output_dim_expr_attrs);
// generate symbol_bindings
std::set<std::string> symbol_names_in_out_dim_exprs{};
CollectSymbolNames(out_dim_exprs, &symbol_names_in_out_dim_exprs);
GenerateSymbolBindings(ShapeOrDataDimExprs4Value,
*minial_inputs,
symbol_names_in_out_dim_exprs,
/*out*/ symbol_bindings);
return true;
}

} // namespace cinn::dialect
15 changes: 15 additions & 0 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

#pragma once

#include <functional>
#include <optional>
#include <vector>
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
Expand Down Expand Up @@ -46,4 +48,17 @@ MakeGetterDimExpr4SymbolName(
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim);

using ShapeOrDataDimExprs4ValueT =
std::function<const symbol::ShapeOrDataDimExprs&(pir::Value)>;

// Returns true if success.
bool MakeGenerateShapeOpAttribute(
pir::IrContext* ir_context,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<symbol::DimExpr>& out_dim_exprs,
const std::vector<pir::Value>& origin_inputs,
std::vector<pir::Value>* minial_inputs,
std::vector<pir::Attribute>* output_dim_expr_attrs,
GenerateShapeOp::SymbolBindings* symbol_bindings);

} // namespace cinn::dialect
Loading

0 comments on commit 54b95ae

Please sign in to comment.