Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#41 from tc20042008/xk-cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
implement FuseFilteredStmtPatterns
  • Loading branch information
tc20042008 committed Mar 8, 2024
2 parents a745eb0 + b1c9cb8 commit badeae6
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 83 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/frontend/group_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ template<>
struct ErrorPattern<frontend::FrontendPattern> {
explicit ErrorPattern(const ErrorPattern<frontend::FrontendPatterns>& other) = default;

const pir::Operation* op;
std::vector<const pir::Operation*> ops;
std::string error_string;
};

Expand Down
190 changes: 108 additions & 82 deletions paddle/cinn/frontend/group_pattern_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "paddle/cinn/common/bfs_walker.h"
#include "paddle/cinn/hlir/framework/op.h"
#include <optional>
#include <typeinfo>
#include <algorithm>

namespace cinn::frontend {

Expand Down Expand Up @@ -148,15 +150,6 @@ class StmtFusionHelper {
return MultiFuse(IsISPattern, ConstructISPattern, stmts);
}

std::optional<ErrorGroupPattern> Fuse_IS_x_PS_2_PS(std::list<StmtPattern>* stmt_patterns) const {
return FuseIternalPattenPrototype(
stmt_patterns,
[](const StmtPattern& upstream, const StmtPattern& downstream){
return IsISPattern(upstream) && IsPSPattern(downstream);
}
);
}

std::optional<ErrorGroupPattern> Fuse_PS_x_PS_2_PS(std::list<StmtPattern>* stmt_patterns) const {
const auto ConstructPSPattern = [&](const auto& ops) {
const auto shardable_axes_signature = GetShardableAxesSignature(ops);
Expand All @@ -168,22 +161,88 @@ class StmtFusionHelper {
return MultiFuse(IsPSPattern, ConstructISPattern, stmts);
}

std::optional<ErrorGroupPattern> Fuse_IS_x_R_2_R(std::list<StmtPattern>* stmt_patterns) const {
return FuseIternalPattenPrototype(
stmt_patterns,
[](const StmtPattern& upstream, const StmtPattern& downstream){
return IsISPattern(upstream) && IsRPattern(downstream);
struct FusePolicy_IS_x_PS_2_PS {
static bool FuseCondition(const StmtPattern& upstream, const StmtPattern& downstream) {
return IsISPattern(upstream) && IsPSPattern(downstream);
}
static std::variant<StmtPattern, ErrorGroupPattern> MergePattern(
const StmtPattern& upstream, const StmtPattern& downstream) {
return MergePatternImpl(std::get<IS>(upstream), std::get<PS>(downstream));
}
static std::variant<StmtPattern, ErrorGroupPattern> MergePatternImpl(
const IS& upstream,
const PS& downstream) {
const auto& ops = [&]{
std::vector<const pir::Operation*> ops;
ops.insert(ops.end(), upstream.ops.begin(), upstream.ops.end());
ops.insert(ops.end(), downstream.ops.begin(), downstream.ops.end());
std::unique(ops.begin(), ops.end());
return ops;
}();
const auto& shardable_axes_signature = MergeShardableAxesSignature(upstream, downstream);
return PS{
.ops=ops,
.shardable_axes_signature=shardable_axes_signature,
};
}
};

std::optional<ErrorGroupPattern> Fuse_IS_x_PS_2_PS(std::list<StmtPattern>* stmt_patterns) const {
return FuseFilteredStmtPatterns<FusePolicy_IS_x_PS_2_PS>(stmt_patterns);
}

struct FusePolicy_IS_x_R_2_R {
static bool FuseCondition(const StmtPattern& upstream, const StmtPattern& downstream) {
return IsISPattern(upstream) && IsRPattern(downstream);
}
static std::variant<StmtPattern, ErrorGroupPattern> MergePattern(
const StmtPattern& upstream, const StmtPattern& downstream) {
return MergePatternImpl(std::get<IS>(upstream), std::get<R>(downstream));
}
static std::variant<StmtPattern, ErrorGroupPattern> MergePatternImpl(
const IS& upstream,
const R& downstream) {
if (downstream.opt_inputs.has_value()) {
return ErrorGroupPattern{
.ops={downstream.reduction_op_pattern.reduce_op},
.error_string="The input of reduce has been fused.",
};
}
);
R new_pattern = R(downstream);
new_pattern.opt_inputs = upstream;
return new_pattern;
}
};

std::optional<ErrorGroupPattern> Fuse_IS_x_R_2_R(std::list<StmtPattern>* stmt_patterns) const {
return FuseFilteredStmtPatterns<FusePolicy_IS_x_R_2_R>(stmt_patterns);
}

std::optional<ErrorGroupPattern> Fuse_PS_x_R_2_R(std::list<StmtPattern>* stmt_patterns) const {
return FuseIternalPattenPrototype(
stmt_patterns,
[](const StmtPattern& upstream, const StmtPattern& downstream){
return IsPSPattern(upstream) && IsRPattern(downstream);
struct FusePolicy_PS_x_R_2_R {
static bool FuseCondition(const StmtPattern& upstream, const StmtPattern& downstream) {
return IsISPattern(upstream) && IsRPattern(downstream);
}
static std::variant<StmtPattern, ErrorGroupPattern> MergePattern(
const StmtPattern& upstream, const StmtPattern& downstream) {
return MergePatternImpl(std::get<PS>(upstream), std::get<R>(downstream));
}
static std::variant<StmtPattern, ErrorGroupPattern> MergePatternImpl(
const PS& upstream,
const R& downstream) {
if (downstream.opt_inputs.has_value()) {
return ErrorGroupPattern{
.ops={downstream.reduction_op_pattern.reduce_op},
.error_string="The input of reduce has been fused.",
};
}
);
R new_pattern = R(downstream);
new_pattern.opt_inputs = upstream;
return new_pattern;
}
};

std::optional<ErrorGroupPattern> Fuse_PS_x_R_2_R(std::list<StmtPattern>* stmt_patterns) const {
return FuseFilteredStmtPatterns<FusePolicy_PS_x_R_2_R>(stmt_patterns);
}

private:
Expand Down Expand Up @@ -398,81 +457,48 @@ class StmtFusionHelper {
LOG(FATAL) << "TODO(wuzhanfei).";
}

std::variant<StmtPattern, ErrorGroupPattern> MergePattern(
const IS& upstream,
const PS& downstream){
PS new_pattern = PS(downstream);
new_pattern.ops.insert(new_pattern.end(), upstream.begin(), upstream.end());
return new_pattern;
}

std::variant<StmtPattern, ErrorGroupPattern> MergePattern(
const PS& upstream,
const PS& downstream){
PS new_pattern = PS(downstream);
new_pattern.ops.insert(new_pattern.end(), upstream.begin(), upstream.end());
return new_pattern
}

std::variant<StmtPattern, ErrorGroupPattern> MergePattern(
const IS& upstream,
const R& downstream){
R new_pattern = R(downstream);
new_pattern.opt_inputs = IS(upstream);
return new_pattern;
}

std::variant<StmtPattern, ErrorGroupPattern> MergePattern(
const PS& upstream,
const R& downstream){
R new_pattern = R(downstream);
new_pattern.opt_inputs = PS(upstream);
return new_pattern;
}
struct StmtIterPair {
StmtIter upstream_iter;
StmtIter downstream_iter;
};

std::optional<std::pair<StmtPattern, StmtPattern>> FindConnetedPattenPairWithCondition(
template <typename FuseTargetConditionT>
std::optional<StmtIterPair> FindConnetedPattenPairWithCondition(
std::list<StmtPattern>* stmt_patterns,
std::function<bool(const StmtPattern& upstream, const StmtPattern& downstream)>& FuseTargetCondition) const {
for (int i=0; i<stmt_patterns.size(); i++){
for (int j=i+1; j<stmt_patterns.size(); j++){
bool i_used_j = FirstIsUpstreamOfSecond(stmt_patterns[j], stmt_patterns[i]);
bool j_used_i = FirstIsUpstreamOfSecond(stmt_patterns[i], stmt_patterns[j]);

if (i_used_j && FuseTargetCondition(stmt_patterns[j], stmt_patterns[i])){
return std::make_pair(stmt_patterns[j], stmt_patterns[i]);
}else if(j_used_i && FuseTargetCondition(stmt_patterns[i], stmt_patterns[j])){
return std::make_pair(stmt_patterns[i], stmt_patterns[j]);
}else{
continue;
const FuseTargetConditionT& FuseTargetCondition) const {
for (auto dst_iter = stmt_patterns->begin(); dst_iter != stmt_patterns->end(); ++dst_iter) {
for (auto src_iter = stmt_patterns->begin(); src_iter != stmt_patterns->end(); ++src_iter) {
if (src_iter == dst_iter) continue;
if (!IsConnected(*src_iter, *dst_iter)) continue;
if (FuseTargetCondition(*src_iter, *dst_iter)) {
return StmtPattern{
.upstream_iter=src_iter,
.downstream_iter=dst_iter,
}
}
}
}
return std::nullopt;
}

std::optional<ErrorGroupPattern> FuseIternalPattenPrototype(
std::list<StmtPattern>* stmt_patterns,
std::function<bool(const StmtPattern&, const StmtPattern&)>& FuseTargetCondition) const{

template <typename FusionPolicy>
std::optional<ErrorGroupPattern> FuseFilteredStmtPatterns(
std::list<StmtPattern>* stmt_patterns) const{
while(true){
const auto& pattern_pair = FindConnetedPattenPairWithCondition(
stmt_patterns, FuseTargetCondition
);
if (!pattern_pair.value()){
break;
}
stmt_patterns, &FusionPolicy::FuseCondition);
if (!pattern_pair.value()) break;
const std::variant<StmtPattern, ErrorGroupPattern>& new_pattern =
MergePattern(pattern_pair.first, pattern_pair.second);
FusionPolicy::MergePattern(*pattern_pair.value().upstream_iter, *pattern_pair.value().downstream_iter);

if (IsErrorGroupPattern(new_pattern)){
return new_pattern;
if (std::holds_alternative<ErrorGroupPattern>(new_pattern)){
return std::get<ErrorGroupPattern>(new_pattern);
}

iternal_patterns.erase(pattern_pair.first);
iternal_patterns.erase(pattern_pair.second);
stmt_patterns->emplace_back(new_pattern);
stmt_patterns->erase(pattern_pair.value().upstream_iter);
stmt_patterns->erase(pattern_pair.value().downstream_iter);
stmt_patterns->emplace_back(std::get<StmtPattern>(new_pattern));
}
return {};
return std::nullopt;
}

ShardableAxesSignature GetShardableAxesSignature(const std::vector<const pir::Operation*>& ops) const {
Expand Down

0 comments on commit badeae6

Please sign in to comment.