Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add align_extent(), to align extent but not min #5829

Merged
merged 3 commits into from
Mar 27, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/AllocationBoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,20 @@ class AllocationInference : public IRMutator {
extent = simplify((max - min) + 1);
}
if (bound.modulus.defined()) {
internal_assert(bound.remainder.defined());
min -= bound.remainder;
min = (min / bound.modulus) * bound.modulus;
min += bound.remainder;
Expr max_plus_one = max + 1;
max_plus_one -= bound.remainder;
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
max_plus_one += bound.remainder;
extent = simplify(max_plus_one - min);
max = max_plus_one - 1;
if (bound.remainder.defined()) {
min -= bound.remainder;
min = (min / bound.modulus) * bound.modulus;
min += bound.remainder;
Expr max_plus_one = max + 1;
max_plus_one -= bound.remainder;
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
max_plus_one += bound.remainder;
extent = simplify(max_plus_one - min);
max = max_plus_one - 1;
} else {
extent = simplify(((extent + bound.modulus - 1) / bound.modulus) * bound.modulus);
max = simplify(min + extent - 1);
}
}

Expr min_var = Variable::make(Int(32), min_name);
Expand Down
22 changes: 14 additions & 8 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,14 +578,20 @@ class BoundsInference : public IRMutator {
}

if (bound.modulus.defined()) {
min_required -= bound.remainder;
min_required = (min_required / bound.modulus) * bound.modulus;
min_required += bound.remainder;
Expr max_plus_one = max_required + 1;
max_plus_one -= bound.remainder;
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
max_plus_one += bound.remainder;
max_required = max_plus_one - 1;
if (bound.remainder.defined()) {
min_required -= bound.remainder;
min_required = (min_required / bound.modulus) * bound.modulus;
min_required += bound.remainder;
Expr max_plus_one = max_required + 1;
max_plus_one -= bound.remainder;
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
max_plus_one += bound.remainder;
max_required = max_plus_one - 1;
} else {
Expr extent = (max_required - min_required) + 1;
extent = simplify(((extent + bound.modulus - 1) / bound.modulus) * bound.modulus);
max_required = simplify(min_required + extent - 1);
}
s = LetStmt::make(min_var, min_required, s);
s = LetStmt::make(max_var, max_required, s);
}
Expand Down
12 changes: 7 additions & 5 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2255,15 +2255,17 @@ Func &Func::bound_extent(const Var &var, Expr extent) {

Func &Func::align_bounds(const Var &var, Expr modulus, Expr remainder) {
user_assert(modulus.defined()) << "modulus is undefined\n";
user_assert(remainder.defined()) << "remainder is undefined\n";
user_assert(Int(32).can_represent(modulus.type())) << "Can't represent modulus as int32\n";
user_assert(Int(32).can_represent(remainder.type())) << "Can't represent remainder as int32\n";

modulus = cast<int32_t>(modulus);
remainder = cast<int32_t>(remainder);
if (remainder.defined()) {
user_assert(Int(32).can_represent(remainder.type())) << "Can't represent remainder as int32\n";

// Reduce the remainder
remainder = remainder % modulus;
remainder = cast<int32_t>(remainder);

// Reduce the remainder
remainder = remainder % modulus;
}

invalidate_cache();

Expand Down
4 changes: 3 additions & 1 deletion src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1551,7 +1551,9 @@ class Func {
* f.align_bounds(x, 2, 1) forces the min to be odd and the extent
* to be even. The region computed always contains the region that
* would have been computed without this directive, so no
* assertions are injected. */
* assertions are injected. If your pass an undefined Expr() for remainder,
* only the extent will be modified as above; the min will remain untouched.
*/
Func &align_bounds(const Var &var, Expr modulus, Expr remainder = 0);

/** Bound the extent of a Func's realization, but not its
Expand Down
51 changes: 51 additions & 0 deletions test/correctness/align_bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,57 @@ int main(int argc, char **argv) {
}
}

// Now try a case where we align the extent but not the min.
{
Func f, g, h;
Var x;

f(x) = 3;

g(x) = select(x % 2 == 0, f(x + 1), f(x - 1) + 8);

Param<int> p;
h(x) = g(x - p) + g(x + p);

f.compute_root();
g.compute_root().align_bounds(x, 32, Expr()).trace_realizations();

p.set(3);
h.set_custom_trace(my_trace);
Buffer<int> result = h.realize({10});

for (int i = 0; i < 10; i++) {
int correct = (i & 1) == 1 ? 6 : 22;
if (result(i) != correct) {
printf("result(%d) = %d instead of %d\n",
i, result(i), correct);
return -1;
}
}

// Now the min/max should stick to odd numbers
if (trace_min != -3 || trace_extent != 32) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}

// Increasing p by one should have no effect
p.set(4);
h.realize(result);
if (trace_min != -4 || trace_extent != 32) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}

// But increasing it again should cause a jump of two in the bounds computed.
p.set(5);
h.realize(result);
if (trace_min != -5 || trace_extent != 32) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}
}

printf("Success!\n");
return 0;
}