diff --git a/paddle/cinn/api/op_topo_pattern.h b/paddle/cinn/api/op_topo_pattern.h index 8febb35a20e6e..1273b0b37280a 100644 --- a/paddle/cinn/api/op_topo_pattern.h +++ b/paddle/cinn/api/op_topo_pattern.h @@ -8,14 +8,23 @@ namespace cinn::api { template struct InjectiveSourcePattern {}; -// Reduce ops +// Reduce op template -struct ReductionPattern {}; +struct SingleReductionOpPattern {}; // ElementWise/Broadcast ops which have shardable dimentions and reduction ancestors. template struct PartialShardablePattern {}; +// Reduce base pattern +template +struct ReductionPattern { + using Nothing = std::monostate; + std::variant, PartialShardablePattern> opt_is_or_ps_input; + SingleReductionOpPattern reduction_op_pattern; +}; + + // SR := [R | PS] template using ShardableReductionsPattern = std::vector, PartialShardablePattern>>; @@ -23,8 +32,8 @@ using ShardableReductionsPattern = std::vector, // fuse rules: // 1. IS * PS -> PS // 2. PS * PS -> PS -// 3. PS * R -> R -// 4. IS * R -> R +// 3. IS * R -> R +// 4. PS * R -> R // lifting rules: // 1. R -> SR