Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add align_extent(), to align extent but not min #5829

Merged
merged 3 commits into from
Mar 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python_bindings/src/PyFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ void define_func(py::module &m) {
.def("set_estimates", &Func::set_estimates, py::arg("estimates"))

.def("align_bounds", &Func::align_bounds, py::arg("var"), py::arg("modulus"), py::arg("remainder") = 0)
.def("align_extent", &Func::align_extent, py::arg("var"), py::arg("modulus"))

.def("bound_extent", &Func::bound_extent, py::arg("var"), py::arg("extent"))

Expand Down
24 changes: 14 additions & 10 deletions src/AllocationBoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,20 @@ class AllocationInference : public IRMutator {
extent = simplify((max - min) + 1);
}
if (bound.modulus.defined()) {
internal_assert(bound.remainder.defined());
min -= bound.remainder;
min = (min / bound.modulus) * bound.modulus;
min += bound.remainder;
Expr max_plus_one = max + 1;
max_plus_one -= bound.remainder;
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
max_plus_one += bound.remainder;
extent = simplify(max_plus_one - min);
max = max_plus_one - 1;
if (bound.remainder.defined()) {
min -= bound.remainder;
min = (min / bound.modulus) * bound.modulus;
min += bound.remainder;
Expr max_plus_one = max + 1;
max_plus_one -= bound.remainder;
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
max_plus_one += bound.remainder;
extent = simplify(max_plus_one - min);
max = max_plus_one - 1;
} else {
extent = simplify(((extent + bound.modulus - 1) / bound.modulus) * bound.modulus);
max = simplify(min + extent - 1);
}
}

Expr min_var = Variable::make(Int(32), min_name);
Expand Down
22 changes: 14 additions & 8 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,14 +578,20 @@ class BoundsInference : public IRMutator {
}

if (bound.modulus.defined()) {
min_required -= bound.remainder;
min_required = (min_required / bound.modulus) * bound.modulus;
min_required += bound.remainder;
Expr max_plus_one = max_required + 1;
max_plus_one -= bound.remainder;
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
max_plus_one += bound.remainder;
max_required = max_plus_one - 1;
if (bound.remainder.defined()) {
min_required -= bound.remainder;
min_required = (min_required / bound.modulus) * bound.modulus;
min_required += bound.remainder;
Expr max_plus_one = max_required + 1;
max_plus_one -= bound.remainder;
max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus;
max_plus_one += bound.remainder;
max_required = max_plus_one - 1;
} else {
Expr extent = (max_required - min_required) + 1;
extent = simplify(((extent + bound.modulus - 1) / bound.modulus) * bound.modulus);
max_required = simplify(min_required + extent - 1);
}
s = LetStmt::make(min_var, min_required, s);
s = LetStmt::make(max_var, max_required, s);
}
Expand Down
21 changes: 20 additions & 1 deletion src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2264,7 +2264,6 @@ Func &Func::align_bounds(const Var &var, Expr modulus, Expr remainder) {

// Reduce the remainder
remainder = remainder % modulus;

invalidate_cache();

bool found = func.is_pure_arg(var.name());
Expand All @@ -2279,6 +2278,26 @@ Func &Func::align_bounds(const Var &var, Expr modulus, Expr remainder) {
return *this;
}

Func &Func::align_extent(const Var &var, Expr modulus) {
user_assert(modulus.defined()) << "modulus is undefined\n";
user_assert(Int(32).can_represent(modulus.type())) << "Can't represent modulus as int32\n";

modulus = cast<int32_t>(modulus);

invalidate_cache();

bool found = func.is_pure_arg(var.name());
user_assert(found)
<< "Can't align extent of variable " << var.name()
<< " of function " << name()
<< " because " << var.name()
<< " is not one of the pure variables of " << name() << ".\n";

Bound b = {var.name(), Expr(), Expr(), modulus, Expr()};
func.schedule().bounds().push_back(b);
return *this;
}

Func &Func::tile(const VarOrRVar &x, const VarOrRVar &y,
const VarOrRVar &xo, const VarOrRVar &yo,
const VarOrRVar &xi, const VarOrRVar &yi,
Expand Down
12 changes: 11 additions & 1 deletion src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1551,9 +1551,19 @@ class Func {
* f.align_bounds(x, 2, 1) forces the min to be odd and the extent
* to be even. The region computed always contains the region that
* would have been computed without this directive, so no
* assertions are injected. */
* assertions are injected.
*/
Func &align_bounds(const Var &var, Expr modulus, Expr remainder = 0);

/** Expand the region computed so that the extent is a
* multiple of 'modulus'. For example, f.align_extent(x, 2) forces
* the extent realized to be even. The region computed always contains the
* region that would have been computed without this directive, so no
* assertions are injected. (This is essentially equivalent to align_bounds(),
* but always leaving the min untouched.)
*/
Func &align_extent(const Var &var, Expr modulus);

/** Bound the extent of a Func's realization, but not its
* min. This means the dimension can be unrolled or vectorized
* even when its min is not fixed (for example because it is
Expand Down
1 change: 1 addition & 0 deletions src/Generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2188,6 +2188,7 @@ class GeneratorOutputBase : public GIOBase {
// @{
HALIDE_FORWARD_METHOD(Func, add_trace_tag)
HALIDE_FORWARD_METHOD(Func, align_bounds)
HALIDE_FORWARD_METHOD(Func, align_extent)
HALIDE_FORWARD_METHOD(Func, align_storage)
HALIDE_FORWARD_METHOD_CONST(Func, args)
HALIDE_FORWARD_METHOD(Func, bound)
Expand Down
5 changes: 3 additions & 2 deletions src/Schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ struct Bound {

/** If defined, the number of iterations will be a multiple of
* "modulus", and the first iteration will be at a value congruent
* to "remainder" modulo "modulus". Set by Func::align_bounds. */
* to "remainder" modulo "modulus". Set by Func::align_bounds and
* Func::align_extent. */
Expr modulus, remainder;
};

Expand Down Expand Up @@ -557,7 +558,7 @@ class FuncSchedule {

/** You may explicitly bound some of the dimensions of a function,
* or constrain them to lie on multiples of a given factor. See
* \ref Func::bound and \ref Func::align_bounds */
* \ref Func::bound and \ref Func::align_bounds and \ref Func::align_extent. */
// @{
const std::vector<Bound> &bounds() const;
std::vector<Bound> &bounds();
Expand Down
51 changes: 51 additions & 0 deletions test/correctness/align_bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,57 @@ int main(int argc, char **argv) {
}
}

// Now try a case where we align the extent but not the min.
{
Func f, g, h;
Var x;

f(x) = 3;

g(x) = select(x % 2 == 0, f(x + 1), f(x - 1) + 8);

Param<int> p;
h(x) = g(x - p) + g(x + p);

f.compute_root();
g.compute_root().align_extent(x, 32).trace_realizations();

p.set(3);
h.set_custom_trace(my_trace);
Buffer<int> result = h.realize({10});

for (int i = 0; i < 10; i++) {
int correct = (i & 1) == 1 ? 6 : 22;
if (result(i) != correct) {
printf("result(%d) = %d instead of %d\n",
i, result(i), correct);
return -1;
}
}

// Now the min/max should stick to odd numbers
if (trace_min != -3 || trace_extent != 32) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}

// Increasing p by one should have no effect
p.set(4);
h.realize(result);
if (trace_min != -4 || trace_extent != 32) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}

// But increasing it again should cause a jump of two in the bounds computed.
p.set(5);
h.realize(result);
if (trace_min != -5 || trace_extent != 32) {
printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent);
return -1;
}
}

printf("Success!\n");
return 0;
}