Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#23 from tc20042008/xk-cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
redefine op_topo_pattern.ReductionPattern
  • Loading branch information
tc20042008 committed Mar 6, 2024
2 parents 6aa34d7 + 00729a9 commit 2427ca7
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions paddle/cinn/api/op_topo_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,32 @@ namespace cinn::api {
template <typename T>
struct InjectiveSourcePattern {};

// Reduce ops
// Reduce op
template <typename T>
struct ReductionPattern {};
struct SingleReductionOpPattern {};

// ElementWise/Broadcast ops which have shardable dimentions and reduction ancestors.
template <typename T>
struct PartialShardablePattern {};

// Reduce base pattern
template <typename T>
struct ReductionPattern {
using Nothing = std::monostate;
std::variant<Nothing, InjectiveSourcePattern<T>, PartialShardablePattern> opt_is_or_ps_input;
SingleReductionOpPattern<T> reduction_op_pattern;
};


// SR := [R | PS]
template <typename T>
using ShardableReductionsPattern = std::vector<std::variant<ReductionPattern<T>, PartialShardablePattern<T>>>;

// fuse rules:
// 1. IS * PS -> PS
// 2. PS * PS -> PS
// 3. PS * R -> R
// 4. IS * R -> R
// 3. IS * R -> R
// 4. PS * R -> R

// lifting rules:
// 1. R -> SR
Expand Down

0 comments on commit 2427ca7

Please sign in to comment.