From 4f152f32a3b34aea4fb0f4ec18dd8211e5a052a4 Mon Sep 17 00:00:00 2001 From: Steven Johnson Date: Sat, 27 Mar 2021 16:16:45 -0700 Subject: [PATCH] Add align_extent(), to align extent but not min (#5829) * Allow align_bounds() to align extent but not min This can be handy when you have an intermediate Func that is being tiled inside an outer Func and you want to ensure that it fits an exact multiple of tiles. * Add separate align_extent() method --- python_bindings/src/PyFunc.cpp | 1 + src/AllocationBoundsInference.cpp | 24 +++++++++------ src/BoundsInference.cpp | 22 ++++++++----- src/Func.cpp | 21 ++++++++++++- src/Func.h | 12 +++++++- src/Generator.h | 1 + src/Schedule.h | 5 +-- test/correctness/align_bounds.cpp | 51 +++++++++++++++++++++++++++++++ 8 files changed, 115 insertions(+), 22 deletions(-) diff --git a/python_bindings/src/PyFunc.cpp b/python_bindings/src/PyFunc.cpp index a50f5a719cf1..a6c05cc1dd02 100644 --- a/python_bindings/src/PyFunc.cpp +++ b/python_bindings/src/PyFunc.cpp @@ -325,6 +325,7 @@ void define_func(py::module &m) { .def("set_estimates", &Func::set_estimates, py::arg("estimates")) .def("align_bounds", &Func::align_bounds, py::arg("var"), py::arg("modulus"), py::arg("remainder") = 0) + .def("align_extent", &Func::align_extent, py::arg("var"), py::arg("modulus")) .def("bound_extent", &Func::bound_extent, py::arg("var"), py::arg("extent")) diff --git a/src/AllocationBoundsInference.cpp b/src/AllocationBoundsInference.cpp index accb9b921d2b..7ca63e1ed262 100644 --- a/src/AllocationBoundsInference.cpp +++ b/src/AllocationBoundsInference.cpp @@ -83,16 +83,20 @@ class AllocationInference : public IRMutator { extent = simplify((max - min) + 1); } if (bound.modulus.defined()) { - internal_assert(bound.remainder.defined()); - min -= bound.remainder; - min = (min / bound.modulus) * bound.modulus; - min += bound.remainder; - Expr max_plus_one = max + 1; - max_plus_one -= bound.remainder; - max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus; - max_plus_one += bound.remainder; - extent = simplify(max_plus_one - min); - max = max_plus_one - 1; + if (bound.remainder.defined()) { + min -= bound.remainder; + min = (min / bound.modulus) * bound.modulus; + min += bound.remainder; + Expr max_plus_one = max + 1; + max_plus_one -= bound.remainder; + max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus; + max_plus_one += bound.remainder; + extent = simplify(max_plus_one - min); + max = max_plus_one - 1; + } else { + extent = simplify(((extent + bound.modulus - 1) / bound.modulus) * bound.modulus); + max = simplify(min + extent - 1); + } } Expr min_var = Variable::make(Int(32), min_name); diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 2f10b2193542..99848023e0f9 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -578,14 +578,20 @@ class BoundsInference : public IRMutator { } if (bound.modulus.defined()) { - min_required -= bound.remainder; - min_required = (min_required / bound.modulus) * bound.modulus; - min_required += bound.remainder; - Expr max_plus_one = max_required + 1; - max_plus_one -= bound.remainder; - max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus; - max_plus_one += bound.remainder; - max_required = max_plus_one - 1; + if (bound.remainder.defined()) { + min_required -= bound.remainder; + min_required = (min_required / bound.modulus) * bound.modulus; + min_required += bound.remainder; + Expr max_plus_one = max_required + 1; + max_plus_one -= bound.remainder; + max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus; + max_plus_one += bound.remainder; + max_required = max_plus_one - 1; + } else { + Expr extent = (max_required - min_required) + 1; + extent = simplify(((extent + bound.modulus - 1) / bound.modulus) * bound.modulus); + max_required = simplify(min_required + extent - 1); + } s = LetStmt::make(min_var, min_required, s); s = LetStmt::make(max_var, max_required, s); } diff --git a/src/Func.cpp b/src/Func.cpp index a2df2e3adec8..18c98175d019 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -2264,7 +2264,6 @@ Func &Func::align_bounds(const Var &var, Expr modulus, Expr remainder) { // Reduce the remainder remainder = remainder % modulus; - invalidate_cache(); bool found = func.is_pure_arg(var.name()); @@ -2279,6 +2278,26 @@ Func &Func::align_bounds(const Var &var, Expr modulus, Expr remainder) { return *this; } +Func &Func::align_extent(const Var &var, Expr modulus) { + user_assert(modulus.defined()) << "modulus is undefined\n"; + user_assert(Int(32).can_represent(modulus.type())) << "Can't represent modulus as int32\n"; + + modulus = cast(modulus); + + invalidate_cache(); + + bool found = func.is_pure_arg(var.name()); + user_assert(found) + << "Can't align extent of variable " << var.name() + << " of function " << name() + << " because " << var.name() + << " is not one of the pure variables of " << name() << ".\n"; + + Bound b = {var.name(), Expr(), Expr(), modulus, Expr()}; + func.schedule().bounds().push_back(b); + return *this; +} + Func &Func::tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &xo, const VarOrRVar &yo, const VarOrRVar &xi, const VarOrRVar &yi, diff --git a/src/Func.h b/src/Func.h index 34247d0811f8..4602a033923c 100644 --- a/src/Func.h +++ b/src/Func.h @@ -1551,9 +1551,19 @@ class Func { * f.align_bounds(x, 2, 1) forces the min to be odd and the extent * to be even. The region computed always contains the region that * would have been computed without this directive, so no - * assertions are injected. */ + * assertions are injected. + */ Func &align_bounds(const Var &var, Expr modulus, Expr remainder = 0); + /** Expand the region computed so that the extent is a + * multiple of 'modulus'. For example, f.align_extent(x, 2) forces + * the extent realized to be even. The region computed always contains the + * region that would have been computed without this directive, so no + * assertions are injected. (This is essentially equivalent to align_bounds(), + * but always leaving the min untouched.) + */ + Func &align_extent(const Var &var, Expr modulus); + /** Bound the extent of a Func's realization, but not its * min. This means the dimension can be unrolled or vectorized * even when its min is not fixed (for example because it is diff --git a/src/Generator.h b/src/Generator.h index 5cbdcf0a3ca0..ba6e146e31d1 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -2188,6 +2188,7 @@ class GeneratorOutputBase : public GIOBase { // @{ HALIDE_FORWARD_METHOD(Func, add_trace_tag) HALIDE_FORWARD_METHOD(Func, align_bounds) + HALIDE_FORWARD_METHOD(Func, align_extent) HALIDE_FORWARD_METHOD(Func, align_storage) HALIDE_FORWARD_METHOD_CONST(Func, args) HALIDE_FORWARD_METHOD(Func, bound) diff --git a/src/Schedule.h b/src/Schedule.h index 67e36c0503b1..935aa984db7a 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -432,7 +432,8 @@ struct Bound { /** If defined, the number of iterations will be a multiple of * "modulus", and the first iteration will be at a value congruent - * to "remainder" modulo "modulus". Set by Func::align_bounds. */ + * to "remainder" modulo "modulus". Set by Func::align_bounds and + * Func::align_extent. */ Expr modulus, remainder; }; @@ -557,7 +558,7 @@ class FuncSchedule { /** You may explicitly bound some of the dimensions of a function, * or constrain them to lie on multiples of a given factor. See - * \ref Func::bound and \ref Func::align_bounds */ + * \ref Func::bound and \ref Func::align_bounds and \ref Func::align_extent. */ // @{ const std::vector &bounds() const; std::vector &bounds(); diff --git a/test/correctness/align_bounds.cpp b/test/correctness/align_bounds.cpp index 44ef270bdd36..5eaf80c811ce 100644 --- a/test/correctness/align_bounds.cpp +++ b/test/correctness/align_bounds.cpp @@ -146,6 +146,57 @@ int main(int argc, char **argv) { } } + // Now try a case where we align the extent but not the min. + { + Func f, g, h; + Var x; + + f(x) = 3; + + g(x) = select(x % 2 == 0, f(x + 1), f(x - 1) + 8); + + Param p; + h(x) = g(x - p) + g(x + p); + + f.compute_root(); + g.compute_root().align_extent(x, 32).trace_realizations(); + + p.set(3); + h.set_custom_trace(my_trace); + Buffer result = h.realize({10}); + + for (int i = 0; i < 10; i++) { + int correct = (i & 1) == 1 ? 6 : 22; + if (result(i) != correct) { + printf("result(%d) = %d instead of %d\n", + i, result(i), correct); + return -1; + } + } + + // Now the min/max should stick to odd numbers + if (trace_min != -3 || trace_extent != 32) { + printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent); + return -1; + } + + // Increasing p by one should have no effect + p.set(4); + h.realize(result); + if (trace_min != -4 || trace_extent != 32) { + printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent); + return -1; + } + + // But increasing it again should cause a jump of two in the bounds computed. + p.set(5); + h.realize(result); + if (trace_min != -5 || trace_extent != 32) { + printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent); + return -1; + } + } + printf("Success!\n"); return 0; }