Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jul 10, 2019
1 parent 885b182 commit 974286a
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 13 deletions.
11 changes: 4 additions & 7 deletions src/arithmetic/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,13 @@ inline IntervalSet Combine<ir::Div>(Analyzer* analyzer,
}
if (is_one(b->min_value)) return a;
// no relaxation is needed in here due to set is inclusive
if (is_zero(b->min_value)) {
// if b is in [0, k] and we assume that be won't be zero
// then we can directly return a.
CHECK(!is_zero(b->max_value)) << "Divide by zero";
return a;
} else if (analyzer->CanProveGreaterEqual(b->min_value, 1)) {
// NOTE: given b is a single point, we can allow division to error out.
// So it is not harmful to include zero point.
if (analyzer->CanProveGreaterEqual(b->min_value, 0)) {
Expr min_value = a->HasLowerBound() ? a->min_value / b->min_value : neg_inf();
Expr max_value = a->HasUpperBound() ? a->max_value / b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
} else if (analyzer->CanProveGreaterEqual(-b->min_value, 1)) {
} else if (analyzer->CanProveGreaterEqual(-b->min_value, 0)) {
Expr min_value = a->HasUpperBound() ? a->max_value / b->min_value : neg_inf();
Expr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf();
return IntervalSet(min_value, max_value);
Expand Down
12 changes: 6 additions & 6 deletions src/arithmetic/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1024,8 +1024,8 @@ Mutate_(const Min* op, const Expr& self) {
}

// DivMod rules
// Divide up rounding: truc div
// NOTE: trucdiv(x, y) >= floordiv(x, y)
// Divide up rounding: trunc div
// NOTE: truncdiv(x, y) >= floordiv(x, y)
TVM_TRY_REWRITE_IF(min(((x + c1) / c2) * c2, x), x,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
Expand Down Expand Up @@ -1209,8 +1209,8 @@ Mutate_(const Max* op, const Expr& self) {
}

// DivMod rules
// Divide up rounding: truc div
// NOTE: trucdiv(x, y) >= floordiv(x, y)
// Divide up rounding: trunc div
// NOTE: truncdiv(x, y) >= floordiv(x, y)
TVM_TRY_REWRITE_IF(max(((x + c1) / c2) * c2, x), ((x + c1) / c2) * c2,
c2.Eval()->value > 0 &&
c1.Eval()->value + 1 == c2.Eval()->value);
Expand Down Expand Up @@ -1425,7 +1425,7 @@ Mutate_(const LT* op, const Expr& self) {
c1.Eval()->value < 0);

// constant cancelation: only need to make use of one mod
// truc div
// trunc div
TVM_TRY_REWRITE_IF(x * c2 < c1, x < (c1 - 1) / c2 + 1,
c1.Eval()->value > 0 &&
c2.Eval()->value > 0);
Expand Down Expand Up @@ -1457,7 +1457,7 @@ Mutate_(const LT* op, const Expr& self) {
c1.Eval()->value >= 0 &&
c2.Eval()->value < 0);
// DivMod rules
// trucdiv
// truncdiv
TVM_TRY_REWRITE_IF(x / c1 < c2, x < c1 * c2,
c1.Eval()->value > 0 &&
c2.Eval()->value > 0);
Expand Down

0 comments on commit 974286a

Please sign in to comment.