diff --git a/python_bindings/src/PyFunc.cpp b/python_bindings/src/PyFunc.cpp index a50f5a719cf1..a6c05cc1dd02 100644 --- a/python_bindings/src/PyFunc.cpp +++ b/python_bindings/src/PyFunc.cpp @@ -325,6 +325,7 @@ void define_func(py::module &m) { .def("set_estimates", &Func::set_estimates, py::arg("estimates")) .def("align_bounds", &Func::align_bounds, py::arg("var"), py::arg("modulus"), py::arg("remainder") = 0) + .def("align_extent", &Func::align_extent, py::arg("var"), py::arg("modulus")) .def("bound_extent", &Func::bound_extent, py::arg("var"), py::arg("extent")) diff --git a/src/AllocationBoundsInference.cpp b/src/AllocationBoundsInference.cpp index accb9b921d2b..7ca63e1ed262 100644 --- a/src/AllocationBoundsInference.cpp +++ b/src/AllocationBoundsInference.cpp @@ -83,16 +83,20 @@ class AllocationInference : public IRMutator { extent = simplify((max - min) + 1); } if (bound.modulus.defined()) { - internal_assert(bound.remainder.defined()); - min -= bound.remainder; - min = (min / bound.modulus) * bound.modulus; - min += bound.remainder; - Expr max_plus_one = max + 1; - max_plus_one -= bound.remainder; - max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus; - max_plus_one += bound.remainder; - extent = simplify(max_plus_one - min); - max = max_plus_one - 1; + if (bound.remainder.defined()) { + min -= bound.remainder; + min = (min / bound.modulus) * bound.modulus; + min += bound.remainder; + Expr max_plus_one = max + 1; + max_plus_one -= bound.remainder; + max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus; + max_plus_one += bound.remainder; + extent = simplify(max_plus_one - min); + max = max_plus_one - 1; + } else { + extent = simplify(((extent + bound.modulus - 1) / bound.modulus) * bound.modulus); + max = simplify(min + extent - 1); + } } Expr min_var = Variable::make(Int(32), min_name); diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 2f10b2193542..99848023e0f9 100644 --- 
a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -578,14 +578,20 @@ class BoundsInference : public IRMutator { } if (bound.modulus.defined()) { - min_required -= bound.remainder; - min_required = (min_required / bound.modulus) * bound.modulus; - min_required += bound.remainder; - Expr max_plus_one = max_required + 1; - max_plus_one -= bound.remainder; - max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus; - max_plus_one += bound.remainder; - max_required = max_plus_one - 1; + if (bound.remainder.defined()) { + min_required -= bound.remainder; + min_required = (min_required / bound.modulus) * bound.modulus; + min_required += bound.remainder; + Expr max_plus_one = max_required + 1; + max_plus_one -= bound.remainder; + max_plus_one = ((max_plus_one + bound.modulus - 1) / bound.modulus) * bound.modulus; + max_plus_one += bound.remainder; + max_required = max_plus_one - 1; + } else { + Expr extent = (max_required - min_required) + 1; + extent = simplify(((extent + bound.modulus - 1) / bound.modulus) * bound.modulus); + max_required = simplify(min_required + extent - 1); + } s = LetStmt::make(min_var, min_required, s); s = LetStmt::make(max_var, max_required, s); } diff --git a/src/Func.cpp b/src/Func.cpp index a2df2e3adec8..18c98175d019 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -2264,7 +2264,6 @@ Func &Func::align_bounds(const Var &var, Expr modulus, Expr remainder) { // Reduce the remainder remainder = remainder % modulus; - invalidate_cache(); bool found = func.is_pure_arg(var.name()); @@ -2279,6 +2278,26 @@ Func &Func::align_bounds(const Var &var, Expr modulus, Expr remainder) { return *this; } +Func &Func::align_extent(const Var &var, Expr modulus) { + user_assert(modulus.defined()) << "modulus is undefined\n"; + user_assert(Int(32).can_represent(modulus.type())) << "Can't represent modulus as int32\n"; + + modulus = cast(modulus); + + invalidate_cache(); + + bool found = func.is_pure_arg(var.name()); + 
user_assert(found) + << "Can't align extent of variable " << var.name() + << " of function " << name() + << " because " << var.name() + << " is not one of the pure variables of " << name() << ".\n"; + + Bound b = {var.name(), Expr(), Expr(), modulus, Expr()}; + func.schedule().bounds().push_back(b); + return *this; +} + Func &Func::tile(const VarOrRVar &x, const VarOrRVar &y, const VarOrRVar &xo, const VarOrRVar &yo, const VarOrRVar &xi, const VarOrRVar &yi, diff --git a/src/Func.h b/src/Func.h index 34247d0811f8..4602a033923c 100644 --- a/src/Func.h +++ b/src/Func.h @@ -1551,9 +1551,19 @@ class Func { * f.align_bounds(x, 2, 1) forces the min to be odd and the extent * to be even. The region computed always contains the region that * would have been computed without this directive, so no - * assertions are injected. */ + * assertions are injected. + */ Func &align_bounds(const Var &var, Expr modulus, Expr remainder = 0); + /** Expand the region computed so that the extent is a + * multiple of 'modulus'. For example, f.align_extent(x, 2) forces + * the extent realized to be even. The region computed always contains the + * region that would have been computed without this directive, so no + * assertions are injected. (This is essentially equivalent to align_bounds(), + * but always leaving the min untouched.) + */ + Func &align_extent(const Var &var, Expr modulus); + /** Bound the extent of a Func's realization, but not its * min. 
This means the dimension can be unrolled or vectorized * even when its min is not fixed (for example because it is diff --git a/src/Generator.h b/src/Generator.h index 5cbdcf0a3ca0..ba6e146e31d1 100644 --- a/src/Generator.h +++ b/src/Generator.h @@ -2188,6 +2188,7 @@ class GeneratorOutputBase : public GIOBase { // @{ HALIDE_FORWARD_METHOD(Func, add_trace_tag) HALIDE_FORWARD_METHOD(Func, align_bounds) + HALIDE_FORWARD_METHOD(Func, align_extent) HALIDE_FORWARD_METHOD(Func, align_storage) HALIDE_FORWARD_METHOD_CONST(Func, args) HALIDE_FORWARD_METHOD(Func, bound) diff --git a/src/Schedule.h b/src/Schedule.h index 67e36c0503b1..935aa984db7a 100644 --- a/src/Schedule.h +++ b/src/Schedule.h @@ -432,7 +432,8 @@ struct Bound { /** If defined, the number of iterations will be a multiple of * "modulus", and the first iteration will be at a value congruent - * to "remainder" modulo "modulus". Set by Func::align_bounds. */ + * to "remainder" modulo "modulus". Set by Func::align_bounds and + * Func::align_extent (the latter leaves "remainder" undefined, so only the extent is aligned). */ Expr modulus, remainder; }; @@ -557,7 +558,7 @@ class FuncSchedule { /** You may explicitly bound some of the dimensions of a function, * or constrain them to lie on multiples of a given factor. See - * \ref Func::bound and \ref Func::align_bounds */ + * \ref Func::bound, \ref Func::align_bounds and \ref Func::align_extent. */ // @{ const std::vector &bounds() const; std::vector &bounds(); diff --git a/test/correctness/align_bounds.cpp b/test/correctness/align_bounds.cpp index 44ef270bdd36..5eaf80c811ce 100644 --- a/test/correctness/align_bounds.cpp +++ b/test/correctness/align_bounds.cpp @@ -146,6 +146,57 @@ int main(int argc, char **argv) { } } + // Now try a case where we align the extent but not the min. 
+ { + Func f, g, h; + Var x; + + f(x) = 3; + + g(x) = select(x % 2 == 0, f(x + 1), f(x - 1) + 8); + + Param p; + h(x) = g(x - p) + g(x + p); + + f.compute_root(); + g.compute_root().align_extent(x, 32).trace_realizations(); + + p.set(3); + h.set_custom_trace(my_trace); + Buffer result = h.realize({10}); + + for (int i = 0; i < 10; i++) { + int correct = (i & 1) == 1 ? 6 : 22; + if (result(i) != correct) { + printf("result(%d) = %d instead of %d\n", + i, result(i), correct); + return -1; + } + } + + // The min should be left where it is (-p) and the extent rounded up to 32 + if (trace_min != -3 || trace_extent != 32) { + printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent); + return -1; + } + + // Increasing p by one should move the min by one but leave the extent at 32 + p.set(4); + h.realize(result); + if (trace_min != -4 || trace_extent != 32) { + printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent); + return -1; + } + + // Increasing it again should do the same: the min tracks p, the extent stays 32. + p.set(5); + h.realize(result); + if (trace_min != -5 || trace_extent != 32) { + printf("%d: Wrong bounds: [%d, %d]\n", __LINE__, trace_min, trace_extent); + return -1; + } + } + + printf("Success!\n"); return 0; }