Skip to content

Commit

Permalink
add split with variable in factors and rewrite vectorize,unroll,bind …
Browse files Browse the repository at this point in the history
…error handling mechanism (#60449)
  • Loading branch information
Courtesy-Xs committed Jan 2, 2024
1 parent 290bf41 commit bd29981
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 10 deletions.
26 changes: 22 additions & 4 deletions paddle/cinn/ir/schedule/impl/for_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,37 @@ void DyScheduleImpl::Parallel(const Expr& loop) {
}

void DyScheduleImpl::Vectorize(const Expr& loop, int factor) {
CINN_IR_SCHEDULE_BEGIN();
std::string primitive = "Vectorize";
std::ostringstream os;
CHECK_GT(factor, 0) << "vectorize factor should be more than 0";
CHECK(loop.As<For>()->extent.is_constant())
<< "The loop to be vectorized should be constant!\n";
if (factor <= 0) {
os << "vectorize factor should be more than 0\n";
throw IRScheduleErrorHandler(primitive, os.str(), module_expr_);
}
if (!loop.As<For>()->extent.is_constant()) {
os << "The loop to be vectorized should be constant!\n";
throw IRScheduleErrorHandler(primitive, os.str(), module_expr_);
}
MutateForType(loop, ForType::Vectorized, factor);
CINN_IR_SCHEDULE_END(this->err_msg_level_);
}

void DyScheduleImpl::Unroll(const Expr& loop) {
CHECK(loop.As<For>()->extent.is_constant())
<< "The loop to be unrolled should be constant!\n";
CINN_IR_SCHEDULE_BEGIN();
std::string primitive = "Unroll";
std::ostringstream os;
if (!loop.As<For>()->extent.is_constant()) {
os << "The loop to be unrolled should be constant!\n";
throw IRScheduleErrorHandler(primitive, os.str(), module_expr_);
}
MutateForType(loop, ForType::Unrolled);
CINN_IR_SCHEDULE_END(this->err_msg_level_);
}

void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
#ifdef CINN_WITH_CUDA
CINN_IR_SCHEDULE_BEGIN();
std::string primitive = "Bind";
std::ostringstream os;

Expand Down Expand Up @@ -117,6 +134,7 @@ void DyScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
}
MutateForType(loop, ForType::GPUThread, offset);
}
CINN_IR_SCHEDULE_END(this->err_msg_level_);
#endif
}
} // namespace ir
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/schedule/impl/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class DyScheduleImpl : public ScheduleBase {
std::vector<Expr> GetChildBlocks(const Expr& expr) const;
Expr GetBlock(const std::string& block_name) const;
std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
std::vector<Expr> Split(const Expr& loop, const std::vector<Expr>& factors);
std::vector<Expr> SamplePerfectTile(
utils::LinearRandomEngine::StateType* rand_seed,
const Expr& loop,
Expand Down Expand Up @@ -122,6 +123,7 @@ class StScheduleImpl : public ScheduleBase {
std::vector<Expr> GetChildBlocks(const Expr& expr) const;
Expr GetBlock(const std::string& block_name) const;
std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
std::vector<Expr> Split(const Expr& loop, const std::vector<Expr>& factors);
std::vector<Expr> SamplePerfectTile(
utils::LinearRandomEngine::StateType* rand_seed,
const Expr& loop,
Expand Down
67 changes: 66 additions & 1 deletion paddle/cinn/ir/schedule/impl/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/common/macros.h"
#include "paddle/cinn/ir/schedule/impl/ir_schedule.h"

#include "paddle/cinn/common/integer_set.h"
#include "paddle/cinn/common/macros.h"

/** \brief A macro that guards the beginning of each implementation of schedule
*/
#define CINN_IR_SCHEDULE_BEGIN() try {
Expand Down Expand Up @@ -157,6 +159,63 @@ std::vector<Expr> DyScheduleImpl::Split(const Expr& loop,
return splited_loops;
}

// TODO(@LiuYang): now -1 can't exsit in factors,
std::vector<Expr> DyScheduleImpl::Split(const Expr& loop,
const std::vector<Expr>& factors) {
CHECK(loop.As<ir::For>())
<< "Expr param of Split must be For node! Please check.";
auto* for_node = loop.As<ir::For>();
CHECK(common::is_zero(for_node->min))
<< "The For node must start with 0! Please check.";
CHECK(!factors.empty())
<< "The factors param of Split should not be empty! Please check.";
CHECK(!loop.As<ir::For>()->extent.is_constant())
<< "Can't Split a loop with constant extent but with variable in "
"factors!";
Expr tot_extent = for_node->extent;

VLOG(3) << "Try Split loop from (" << for_node->loop_var->name << ", 0, "
<< tot_extent << ") to (" << cinn::utils::Join(factors, ", ")
<< ") at loop:\n"
<< loop;

std::vector<Expr> process_factors(factors);
Expr prod_size(1);
for (auto factor : factors) prod_size = prod_size * Expr(factor);
cinn::common::SymbolicExprAnalyzer analyzer({});
CHECK(analyzer.ProveEQ(tot_extent, prod_size).value_or(false))
<< "Product of factors can't be proved to be equal to the extent of "
"current for loop!";

std::vector<Var> new_loop_vars;
Expr substitute_value(0);
for (int i = 0; i < process_factors.size(); ++i) {
Var temp_var(common::UniqName(for_node->loop_var->name));
substitute_value = Expr(temp_var) + substitute_value * process_factors[i];
new_loop_vars.push_back(temp_var);
}
substitute_value = cinn::common::AutoSimplify(substitute_value);
Expr new_node = ir::ir_utils::IRCopy(for_node->body);
ReplaceExpr(&new_node, {for_node->loop_var}, {substitute_value});
std::vector<Expr> splited_loops;
splited_loops.resize(process_factors.size());

for (int i = process_factors.size() - 1; i >= 0; i--) {
if (!new_node.As<ir::Block>()) new_node = Block::Make({new_node});
new_node = For::Make(new_loop_vars[i],
Expr(0),
process_factors[i],
for_node->for_type(),
for_node->device_api,
new_node);
splited_loops[i] = new_node;
}

this->Replace(loop, new_node);
VLOG(3) << "After Split, ir is:\n" << splited_loops.at(0);
return splited_loops;
}

Expr DyScheduleImpl::Fuse(const std::vector<Expr>& loops) {
VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n");
std::vector<const ir::For*> for_nodes;
Expand Down Expand Up @@ -370,6 +429,12 @@ std::vector<Expr> StScheduleImpl::Split(const Expr& loop,
return splited_loops;
}

std::vector<Expr> StScheduleImpl::Split(const Expr& loop,
const std::vector<Expr>& factors) {
CHECK(false) << "Static shape schedule don't support Split with some "
"variables in factors";
}

Expr StScheduleImpl::Fuse(const std::vector<Expr>& loops) {
VLOG(3) << "Tring to fuse:\n" << cinn::utils::Join(loops, "\n");
std::vector<const ir::For*> for_nodes;
Expand Down
15 changes: 10 additions & 5 deletions paddle/cinn/ir/schedule/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -405,11 +405,16 @@ std::vector<Expr> IRSchedule::Split(const std::string& block_name,
std::vector<Expr> IRSchedule::Split(const Expr& loop,
const std::vector<Expr>& factors) {
std::vector<int> int_factors;
std::transform(factors.begin(),
factors.end(),
std::back_inserter(int_factors),
[](Expr x) { return x.as_int32(); });
auto results = impl_->Split(loop, int_factors);
std::vector<Expr> results;
std::for_each(factors.begin(), factors.end(), [&int_factors](const Expr& e) {
if (e.is_constant()) int_factors.push_back(e.as_int32());
});
if (int_factors.size() == factors.size()) {
results = impl_->Split(loop, int_factors);
} else {
results = impl_->Split(loop, factors);
}

trace_.Append(ScheduleDesc::Step(
"Split",
{{"loop", std::vector<Expr>({loop})}, {"factors", factors}},
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/schedule/schedule_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class ScheduleBase {
virtual Expr GetBlock(const std::string& block_name) const = 0;
virtual std::vector<Expr> Split(const Expr& loop,
const std::vector<int>& factors) = 0;
virtual std::vector<Expr> Split(const Expr& loop,
const std::vector<Expr>& factors) = 0;
virtual std::vector<Expr> SamplePerfectTile(
utils::LinearRandomEngine::StateType* rand_seed,
const Expr& loop,
Expand Down

0 comments on commit bd29981

Please sign in to comment.