Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for codegen #18

Merged
merged 1 commit into from
Jan 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 67 additions & 45 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,9 @@ class IterMapRewriter : public ExprMutator {
return NormalizeToIterWithOffset(ToIterSumExpr(DirectMutate(expr)));
}

IterSumExpr RewriteIterConstraint(const PrimExpr& expr, const PrimExpr& predicate_induced_min,
const PrimExpr& predicate_induced_max) {
IterSumExpr RewriteIterConstraint(const PrimExpr& expr,
const Optional<PrimExpr>& predicate_induced_min,
const Optional<PrimExpr>& predicate_induced_max) {
return NormalizeToIterOnBoundExpr(ToIterSumExpr(DirectMutate(expr)), predicate_induced_min,
predicate_induced_max);
}
Expand Down Expand Up @@ -494,14 +495,16 @@ class IterMapRewriter : public ExprMutator {
* \param predicate_induced_max Open upper bound from iter constraint, may be undefined.
* \return The Normalized expression.
*/
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, PrimExpr predicate_induced_min,
PrimExpr predicate_induced_max) {
IterSumExpr NormalizeToIterOnBoundExpr(IterSumExpr expr, Optional<PrimExpr> predicate_induced_min,
Optional<PrimExpr> predicate_induced_max) {
// normalize to zero base
PrimExpr base = expr->base;
if (!is_zero(base)) {
expr.CopyOnWrite()->base = 0;
if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base;
if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base;
if (predicate_induced_min.defined())
predicate_induced_min = predicate_induced_min.value() - base;
if (predicate_induced_max.defined())
predicate_induced_max = predicate_induced_max.value() - base;
}
Optional<IterSumExpr> opt = TryFuseIters(expr);
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
Expand All @@ -521,27 +524,28 @@ class IterMapRewriter : public ExprMutator {
PrimExpr iter_min = mark_offset;
PrimExpr iter_max = iter_min + mark->extent;
if (predicate_induced_min.defined()) {
iter_min = max(predicate_induced_min, iter_min);
iter_min = max(predicate_induced_min.value(), iter_min);
}
if (predicate_induced_max.defined()) {
iter_max = min(predicate_induced_max, iter_max);
iter_max = min(predicate_induced_max.value(), iter_max);
}
if (!is_zero(iter_min)) {
// structured form's offset should be updated
flattened_map_.erase(structured_form);
structured_form.CopyOnWrite()->base = -iter_min;
mark.CopyOnWrite()->source = structured_form;
flattened_map_[structured_form] = flattened_form;
if (analyzer_->CanProve(iter_min <= iter_max)) {
if (!is_zero(iter_min)) {
// structured form's offset should be updated
flattened_map_.erase(structured_form);
structured_form.CopyOnWrite()->base = -iter_min;
mark.CopyOnWrite()->source = structured_form;
flattened_map_[structured_form] = flattened_form;
}
mark.CopyOnWrite()->extent = iter_max - iter_min;
sum_fuse_map_[flattened_form] = {mark, iter_min};
// we need to note down the flattened form of constrained iterators
// to check the validity of constraints, see also CheckConstraints()
constrained_iters_flattened_.push_back(flattened_form);
expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
expr.CopyOnWrite()->base = base + iter_min;
return expr;
}
mark.CopyOnWrite()->extent = iter_max - iter_min;
sum_fuse_map_[flattened_form] = {mark, iter_min};

// we need to note down the flattened form of constrained iterators
// to check the validity of constraints, see also CheckConstraints()
constrained_iters_flattened_.push_back(flattened_form);
expr.CopyOnWrite()->args = Array<IterSplitExpr>({split});
expr.CopyOnWrite()->base = base + iter_min;
return expr;
}
Fail(Diagnostic::Error(expr->span)
<< "Fail to normalize " << expr << " with predicate bound [" << predicate_induced_min
Expand Down Expand Up @@ -608,7 +612,7 @@ class IterMapRewriter : public ExprMutator {
}
}
}
if (!base_scale) {
if (!base_scale || base_scale.value()->value < 0) {
diag_ctx_.Emit(Diagnostic::Error(expr->span)
<< "Fuse iters failed, can not find a valid base scale");
return NullOpt;
Expand Down Expand Up @@ -770,14 +774,15 @@ class IterMapRewriter : public ExprMutator {
struct IterConstraint {
// The expr of the iter
PrimExpr iter;
// The expr of the lower_bound
PrimExpr lower_bound;
// The expr of the upper_bound
PrimExpr upper_bound;
// The expr of the lower_bound, may be undefined
Optional<PrimExpr> lower_bound;
// The expr of the upper_bound, may be undefined
Optional<PrimExpr> upper_bound;
// The size of the iter, which is the number of nodes
size_t expr_size = 0;

IterConstraint(PrimExpr iter, PrimExpr lower_bound, PrimExpr upper_bound, size_t size)
IterConstraint(PrimExpr iter, Optional<PrimExpr> lower_bound, Optional<PrimExpr> upper_bound,
size_t size)
: iter(std::move(iter)),
lower_bound(std::move(lower_bound)),
upper_bound(std::move(upper_bound)),
Expand All @@ -787,11 +792,11 @@ struct IterConstraint {
/*!
* \brief Split the predicate into `(a < b) && (c < d) && ...`
* \param pred The predicate to be split.
* \param result The result of predicate split.
* \return Whether the predicate was split successfully; false if the split failed.
*/
std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
const Map<Var, Range>& input_iters) {
std::vector<IterConstraint> result;
bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>& input_iters,
std::vector<IterConstraint>& result) {
arith::PVar<PrimExpr> lhs, rhs, rest;
for (;;) {
// try to extract comparisons
Expand Down Expand Up @@ -820,14 +825,14 @@ std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
is_equal = true;
is_finish = true;
} else {
return std::vector<IterConstraint>();
return false;
}
PrimExpr lhs_expr = lhs.Eval();
PrimExpr rhs_expr = rhs.Eval();
// we only accept predicate of integers
if (!((lhs_expr->dtype.is_int() || lhs_expr->dtype.is_uint()) &&
(rhs_expr->dtype.is_int() || rhs_expr->dtype.is_uint()))) {
return std::vector<IterConstraint>();
return false;
}
// determine iter and bound, if we can not distinguish them simply,
// try divide (lhs - rhs) into itervar aware and itervar free parts
Expand Down Expand Up @@ -863,35 +868,49 @@ std::vector<IterConstraint> MatchBoundConstraints(PrimExpr pred,
lhs_expr = analyzer.Simplify(lhs_expr);
rhs_expr = analyzer.Simplify(rhs_expr);
}
PrimExpr lower_bound, upper_bound, iter;
Optional<PrimExpr> lower_bound = NullOpt, upper_bound = NullOpt;
PrimExpr iter;
if (is_greater) {
if (bound_at_left) {
// bound > iter
// bound > iter / bound >= iter
upper_bound = is_equal ? lhs_expr + 1 : lhs_expr;
iter = rhs_expr;
} else {
// iter > bound
// iter > bound / iter >= bound
lower_bound = is_equal ? rhs_expr : rhs_expr + 1;
iter = lhs_expr;
}
} else {
if (bound_at_left) {
// bound < iter
// bound < iter / bound <= iter
lower_bound = is_equal ? lhs_expr : lhs_expr + 1;
iter = rhs_expr;
} else {
// iter < bound
// iter < bound / iter <= bound
upper_bound = is_equal ? rhs_expr + 1 : rhs_expr;
iter = lhs_expr;
}
}
result.emplace_back(iter, lower_bound, upper_bound, 0);
// If it is a predicate for input iters
if (const auto* var_ptr = iter.as<VarNode>()) {
auto it = input_iters.find(GetRef<Var>(var_ptr));
if (it == input_iters.end()) {
return false;
}
PrimExpr iter_min = (*it).second->min;
PrimExpr iter_max = (*it).second->min + (*it).second->extent;
if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value());
if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value());
input_iters.Set(GetRef<Var>(var_ptr), Range(iter_min, iter_max));
} else {
result.emplace_back(iter, lower_bound, upper_bound, 0);
}
if (is_finish) {
break;
}
pred = rest.Eval();
}
return result;
return true;
}

bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
Expand All @@ -911,8 +930,10 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
// - Step0: IterMapRewriter rewrites the expression to use IterMapExpr patterns.
// - Step1: IterIndependenceChecker checks if the iterators are independent.
if (!IterRangeSanityCheck(input_iters)) return Array<IterSumExpr>();
std::vector<IterConstraint> constraints = MatchBoundConstraints(predicate, input_iters);
if (!is_one(predicate) && constraints.empty()) {
Map<Var, Range> constrained_input_iters = input_iters;
std::vector<IterConstraint> constraints;
if (!is_one(predicate) &&
!MatchBoundConstraints(predicate, constrained_input_iters, constraints)) {
diag_ctx.Emit(Diagnostic::Error(predicate->span)
<< "Fail to collect constraints from iteration predicate: " << predicate);
return Array<IterSumExpr>();
Expand All @@ -929,10 +950,11 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
constraints.begin(), constraints.end(),
[](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });

IterMapRewriter rewriter(analyzer, input_iters, diag_ctx);
IterMapRewriter rewriter(analyzer, constrained_input_iters, diag_ctx);
// Step0.0: rewrite constraints in the order from size-small ones to size-big ones
for (const IterConstraint& constraint : constraints) {
rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound, constraint.upper_bound);
auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
constraint.upper_bound);
if (rewriter.unresolved_count() != 0) return Array<IterSumExpr>();
}
if (!rewriter.CheckConstraints()) {
Expand Down
3 changes: 2 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,15 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::InjectVirtualThread());
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::StorageRewrite());
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::UnrollLoop());

// Add user-defined phase-2 passes
pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end());

// PHASE 3
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::RenormalizeSplitPattern());
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::RemoveNoOp());
pass_list.push_back(tir::transform::RewriteUnsafeSelect());
pass_list.push_back(tir::transform::HoistIfThenElse());
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
for (int i = 0; i < n; i++) {
const PrimExpr& factor = factors[i];
Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i));
substitute_value = substitute_value * factor + var;
if (!is_one(factor)) substitute_value = substitute_value * factor + var;
analyzer.Bind(var, Range::FromMinExtent(0, factor));
new_loop_vars.emplace_back(std::move(var));
}
Expand Down
32 changes: 29 additions & 3 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
return result;
}

/*!
 * \brief Evaluate the integer sets covered by a buffer region, taking the block predicate into account.
 *
 * First asks arith::EstimateRegionLowerBound for an estimate of the region under the
 * given variable domains and predicate. If that call yields no result, falls back to
 * support::NDIntSetEval over the region with the original domain map.
 *
 * \param region The buffer region to evaluate.
 * \param predicate The predicate forwarded to arith::EstimateRegionLowerBound.
 * \param dom_map The map from variables to their integer-set domains.
 * \param analyzer The analyzer forwarded to arith::EstimateRegionLowerBound.
 * \return One IntSet per entry produced by the estimate, or the fallback evaluation.
 */
NDIntSet NDIntSetEval(Region region, PrimExpr predicate,
                      std::unordered_map<const VarNode*, arith::IntSet>& dom_map,
                      arith::Analyzer* analyzer) {
  // Re-key the domains by Var and express each one as a Range.
  // NOTE(review): the (0, 0) range is presumably the fallback CoverRange uses when the
  // set cannot be covered -- confirm against IntSet::CoverRange.
  std::unordered_map<Var, Range, ObjectPtrHash, ObjectEqual> var_ranges;
  for (const auto& entry : dom_map) {
    var_ranges[GetRef<Var>(entry.first)] = entry.second.CoverRange(Range::FromMinExtent(0, 0));
  }

  Optional<Array<arith::IntSet>> estimated =
      arith::EstimateRegionLowerBound(region, var_ranges, predicate, analyzer);

  // No estimate available: evaluate the region directly against the domain map.
  if (!estimated.defined()) {
    return support::NDIntSetEval(support::NDIntSetFromRegion(region), dom_map);
  }

  // Copy the estimated sets into the N-dimensional result.
  NDIntSet result(0);
  for (const auto& int_set : estimated.value()) {
    result.push_back(int_set);
  }
  return result;
}

/*!
* \brief Collect the access region of each buffer.
* \note The param buffer regions will not be collected.
Expand Down Expand Up @@ -149,7 +166,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
}
return;
}
return StmtExprVisitor::VisitExpr_(op);
StmtExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const BlockNode* op) final {
Expand Down Expand Up @@ -198,6 +215,13 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
}
}

void VisitStmt_(const BlockRealizeNode* op) final {
PrimExpr cur_predicate = predicate_in_scope;
predicate_in_scope = op->predicate;
StmtExprVisitor::VisitStmt_(op);
predicate_in_scope = cur_predicate;
}

/**************** Helper functions ****************/

void VisitBufferAccess(const BufferRegion& buffer_region) {
Expand All @@ -206,7 +230,6 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
if (it != buffer_var_in_scope_.end()) {
const Buffer& buffer = it->second.first;
size_t n_ancestor_loops = it->second.second;
NDIntSet nd_int_set = support::NDIntSetFromRegion(buffer_region->region);
// Step 1. Stop ancestor loop vars out of the allocation block from
// being relaxed unless NeedRelaxThread() is true.
std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
Expand All @@ -222,7 +245,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
dom_map_.erase(dom_it);
}
// Step 2. Relax the access region
nd_int_set = support::NDIntSetEval(nd_int_set, dom_map_);
NDIntSet nd_int_set =
NDIntSetEval(buffer_region->region, predicate_in_scope, dom_map_, &dom_analyzer_);
// Step 3. Restore the non-relaxed ancestor loops domain
for (size_t i = 0; i < n_ancestor_loops; ++i) {
const VarNode* v = ancestor_loops_[i]->loop_var.get();
Expand Down Expand Up @@ -279,6 +303,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
*/
std::unordered_map<Var, std::pair<Buffer, size_t>, ObjectPtrHash, ObjectPtrEqual>
buffer_var_in_scope_;
/*! \brief The block predicate of current scope */
PrimExpr predicate_in_scope{true};

/*! \brief The map from loop vars to their iter range. */
std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
Expand Down
Loading