From cc03c9c0b4bdd45012c6eac437221f5b4cb8436a Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Fri, 8 Jan 2021 16:28:07 -0800 Subject: [PATCH] Prototype of multiple scattering update definitions (#5553) Add "gather" and "scatter" intrinsics, which let you write update definitions which store multiple values at once to different computed locations. Useful for doing things like swapping or permuting elements in-place. See comments in IROperator.h for more details. --- src/Bounds.cpp | 8 + src/CSE.cpp | 2 +- src/IR.cpp | 1 + src/IR.h | 1 + src/IROperator.cpp | 19 ++ src/IROperator.h | 79 ++++++++ src/Simplify_And.cpp | 1 + src/Simplify_Stmts.cpp | 42 ++++- src/SplitTuples.cpp | 199 +++++++++++++++++++- test/correctness/CMakeLists.txt | 1 + test/correctness/multiple_scatter.cpp | 250 ++++++++++++++++++++++++++ 11 files changed, 599 insertions(+), 4 deletions(-) create mode 100644 test/correctness/multiple_scatter.cpp diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 3749f5006c59..ce42e91f1040 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1425,6 +1425,14 @@ class Bounds : public IRVisitor { } else if (op->is_intrinsic(Call::memoize_expr)) { internal_assert(!op->args.empty()); op->args[0].accept(this); + } else if (op->is_intrinsic(Call::scatter_gather)) { + // Take the union of the args + Interval result = Interval::nothing(); + for (const Expr &e : op->args) { + e.accept(this); + result.include(interval); + } + interval = result; } else if (op->call_type == Call::Halide) { bounds_of_func(op->name, op->value_index, op->type); } else { diff --git a/src/CSE.cpp b/src/CSE.cpp index 552cef94d727..d57d1c18cfaa 100644 --- a/src/CSE.cpp +++ b/src/CSE.cpp @@ -314,7 +314,7 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) { // Wrap the final expr in the lets. for (size_t i = lets.size(); i > 0; i--) { Expr value = lets[i - 1].second; - // Drop this variable as an acceptible replacement for this expr. 
+ // Drop this variable as an acceptable replacement for this expr. replacer.erase(value); // Use containing lets in the value. value = replacer.mutate(lets[i - 1].second); diff --git a/src/IR.cpp b/src/IR.cpp index e1fa4b39de3c..edc293b6ce71 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -626,6 +626,7 @@ const char *const intrinsic_op_names[] = { "require_mask", "return_second", "rewrite_buffer", + "scatter_gather", "select_mask", "shift_left", "shift_right", diff --git a/src/IR.h b/src/IR.h index d91e57cbd681..ce45483882e6 100644 --- a/src/IR.h +++ b/src/IR.h @@ -538,6 +538,7 @@ struct Call : public ExprNode { require_mask, return_second, rewrite_buffer, + scatter_gather, select_mask, shift_left, shift_right, diff --git a/src/IROperator.cpp b/src/IROperator.cpp index dcdad0285caf..19812f163a09 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -2361,4 +2361,23 @@ Expr undef(Type t) { Internal::Call::PureIntrinsic); } +namespace { +Expr make_scatter_gather(const std::vector &args) { + // There's currently no difference in the IR between a gather and + // a scatter. They're distinct just to make code more readable. + return Halide::Internal::Call::make(args[0].type(), + Halide::Internal::Call::scatter_gather, + args, + Halide::Internal::Call::PureIntrinsic); +} +} // namespace + +Expr scatter(const std::vector &args) { + return make_scatter_gather(args); +} + +Expr gather(const std::vector &args) { + return make_scatter_gather(args); +} + } // namespace Halide diff --git a/src/IROperator.h b/src/IROperator.h index fc2dffffcc7a..be39fb321d8c 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -1398,6 +1398,85 @@ namespace Internal { Expr promise_clamped(const Expr &value, const Expr &min, const Expr &max); } // namespace Internal +/** Scatter and gather are used for update definition which must store + * multiple values to distinct locations at the same time. 
The + * multiple expressions on the right-hand-side are bundled together + * into a "gather", which must match a "scatter" with the same number + * of arguments on the left-hand-side. For example, to store the + * values 1 and 2 to the locations (x, y, 3) and (x, y, 4), + * respectively: + * +\code +f(x, y, scatter(3, 4)) = gather(1, 2); +\endcode + * + * The result of gather or scatter can be treated as an + * expression. Any containing operations on it can be assumed to + * distribute over the elements. If two gather expressions are + * combined with an arithmetic operator (e.g. added), they combine + * element-wise. The following example stores the values 2 * x, 2 * y, + * and 2 * c to the locations (x + 1, y, c), (x, y + 3, c), and (x, y, + * c + 2) respectively: + * +\code +f(x + scatter(1, 0, 0), y + scatter(0, 3, 0), c + scatter(0, 0, 2)) = 2 * gather(x, y, c); +\endcode +* +* Repeated values in the scatter cause multiple stores to the same +* location. The stores happen in order from left to right, so the +* rightmost value wins. The following code is equivalent to f(x) = 5 +* +\code +f(scatter(x, x)) = gather(3, 5); +\endcode +* +* Gathers are most useful for algorithms which require in-place +* swapping or permutation of multiple elements, or other kinds of +* in-place mutations that require loading multiple inputs, doing some +* operations to them jointly, then storing them again. The following +* update definition swaps the values of f at locations 3 and 5 if an +* input parameter p is true: +* +\code +f(scatter(3, 5)) = f(select(p, gather(5, 3), gather(3, 5))); +\endcode +* +* For more examples of the use of scatter and gather, see +* test/correctness/multiple_scatter.cpp +* +* It is not currently possible to use scatter and gather to write an +* update definition in which the *number* of values loaded or stored +* varies, as the size of the scatter/gather packet must be fixed at +* compile-time. 
A workaround is to make the unwanted extra operations +* a redundant copy of the last operation, which will be +* dead-code-eliminated by the compiler. For example, the following +* update definition swaps the values at locations 3 and 5 when the +* parameter p is true, and rotates the values at locations 1, 2, and 3 +* when it is false. The load from 3 and store to 5 will be redundantly +* repeated: +* +\code +f(select(p, scatter(3, 5, 5), scatter(1, 2, 3))) = f(select(p, gather(5, 3, 3), gather(2, 3, 1))); +\endcode +* +* Note that in the p == true case, we redudantly load from 3 and write +* to 5 twice. +*/ +//@{ +Expr scatter(const std::vector &args); +Expr gather(const std::vector &args); + +template +Expr scatter(const Expr &e, Args &&... args) { + return scatter({e, std::forward(args)...}); +} + +template +Expr gather(const Expr &e, Args &&... args) { + return gather({e, std::forward(args)...}); +} +// @} + } // namespace Halide #endif diff --git a/src/Simplify_And.cpp b/src/Simplify_And.cpp index 1b35fda68340..31975da9b9cb 100644 --- a/src/Simplify_And.cpp +++ b/src/Simplify_And.cpp @@ -57,6 +57,7 @@ Expr Simplify::visit(const And *op, ExprInfo *bounds) { rewrite(!x && x, false) || rewrite(y <= x && x < y, false) || rewrite(x != c0 && x == c1, b, c0 != c1) || + rewrite(x == c0 && x == c1, false, c0 != c1) || // Note: In the predicate below, if undefined overflow // occurs, the predicate counts as false. If well-defined // overflow occurs, the condition couldn't possibly diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 7e4b40d2905b..117d3112ec19 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -1,5 +1,6 @@ #include "Simplify_Internal.h" +#include "ExprUsesVar.h" #include "IRMutator.h" #include "Substitute.h" @@ -388,6 +389,9 @@ Stmt Simplify::visit(const Block *op) { rest.as() ? rest.as() : (block_rest ? block_rest->first.as() : nullptr); Stmt if_rest = block_rest ? 
block_rest->rest : Stmt(); + const Store *store_first = first.as(); + const Store *store_next = block_rest ? block_rest->first.as() : rest.as(); + if (is_no_op(first) && is_no_op(rest)) { return Evaluate::make(0); @@ -411,11 +415,28 @@ Stmt Simplify::visit(const Block *op) { new_block = substitute(let_rest->name, new_var, new_block); return LetStmt::make(var_name, let_first->value, new_block); + } else if (store_first && + store_next && + store_first->name == store_next->name && + equal(store_first->index, store_next->index) && + equal(store_first->predicate, store_next->predicate) && + is_pure(store_first->index) && + is_pure(store_first->value) && + is_pure(store_first->predicate) && + !expr_uses_var(store_next->index, store_next->name) && + !expr_uses_var(store_next->value, store_next->name) && + !expr_uses_var(store_next->predicate, store_next->name)) { + // Second store clobbers first + if (block_rest) { + return Block::make(store_next, block_rest->rest); + } else { + return store_next; + } } else if (if_first && if_next && equal(if_first->condition, if_next->condition) && is_pure(if_first->condition)) { - // Two ifs with matching conditions + // Two ifs with matching conditions. Stmt then_case = mutate(Block::make(if_first->then_case, if_next->then_case)); Stmt else_case; if (if_first->else_case.defined() && if_next->else_case.defined()) { @@ -448,6 +469,25 @@ Stmt Simplify::visit(const Block *op) { result = Block::make(result, if_rest); } return result; + } else if (if_first && + if_next && + is_pure(if_first->condition) && + is_pure(if_next->condition) && + is_const_one(mutate(!(if_first->condition && if_next->condition), nullptr))) { + // Two ifs where the first condition being true implies the + // second is false. The second if can be nested inside the + // else case of the first one, turning a block of if + // statements into an if-else chain. 
+ Stmt then_case = if_first->then_case; + Stmt else_case = if_next; + if (if_first->else_case.defined()) { + else_case = Block::make(if_first->else_case, else_case); + } + Stmt result = mutate(IfThenElse::make(if_first->condition, then_case, else_case)); + if (if_rest.defined()) { + result = Block::make(result, if_rest); + } + return result; } else if (op->first.same_as(first) && op->rest.same_as(rest)) { return op; diff --git a/src/SplitTuples.cpp b/src/SplitTuples.cpp index dd56a415cf29..cc3745b41579 100644 --- a/src/SplitTuples.cpp +++ b/src/SplitTuples.cpp @@ -1,11 +1,13 @@ #include "SplitTuples.h" #include "Bounds.h" +#include "CSE.h" #include "ExprUsesVar.h" #include "Function.h" #include "IRMutator.h" #include "IROperator.h" #include "Simplify.h" +#include "Substitute.h" namespace Halide { namespace Internal { @@ -335,10 +337,203 @@ class SplitTuples : public IRMutator { } }; +class SplitScatterGather : public IRMutator { + using IRMutator::visit; + + // The enclosing producer node. Used for error messages. 
+ const ProducerConsumer *producer = nullptr; + + class GetScatterGatherSize : public IRVisitor { + bool permitted = true; + using IRVisitor::visit; + void visit(const Call *op) override { + if (op->is_intrinsic(Call::scatter_gather)) { + user_assert(permitted) + << "Can't nest an expression tuple inside another in definition of " + << producer_name; + if (result == 0) { + result = (int)op->args.size(); + } else { + user_assert((int)op->args.size() == result) + << "Expression tuples of mismatched sizes used in definition of " + << producer_name << ": " << result << " vs " << op->args.size(); + } + // No nesting tuples + permitted = false; + IRVisitor::visit(op); + permitted = true; + } else { + IRVisitor::visit(op); + } + } + + public: + string producer_name; + int result = 0; + } get_scatter_gather_size_visitor; + + int get_scatter_gather_size(const IRNode *op) { + if (producer) { + get_scatter_gather_size_visitor.producer_name = producer->name; + } + get_scatter_gather_size_visitor.result = 0; + op->accept(&get_scatter_gather_size_visitor); + + // Maybe the user did something like jam a gather op into a + // constraint or tile size, so this is a user_assert. + user_assert(producer || get_scatter_gather_size_visitor.result == 0) + << "scatter/gather expression used outside of a Func definition"; + + return get_scatter_gather_size_visitor.result; + } + + class ExtractScatterGatherElement : public IRMutator { + using IRMutator::visit; + Expr visit(const Call *op) override { + if (op->is_intrinsic(Call::scatter_gather)) { + // No need to recursively mutate because we've + // already asserted that these aren't nested. + internal_assert(idx < (int)op->args.size()); + return op->args[idx]; + } else { + return IRMutator::visit(op); + } + } + + public: + int idx; + } extractor; + + Stmt visit(const ProducerConsumer *op) override { + ScopedValue old(producer, op->is_producer ? 
op : producer); + return IRMutator::visit(op); + } + + Stmt visit_gather_let_stmt(const LetStmt *op, int size) { + // Split this variable into the gather components + vector> lets; + vector vars; + for (extractor.idx = 0; extractor.idx < size; extractor.idx++) { + string name = unique_name(op->name + "." + std::to_string(extractor.idx)); + lets.emplace_back(name, extractor.mutate(op->value)); + vars.push_back(Variable::make(op->value.type(), name)); + } + + Stmt body = op->body; + Expr gather_replacement = Call::make(op->value.type(), + Call::scatter_gather, + vars, + Call::PureIntrinsic); + body = substitute(op->name, gather_replacement, body); + body = mutate(body); + + for (auto it = lets.rbegin(); it != lets.rend(); it++) { + body = LetStmt::make(it->first, it->second, body); + } + + return body; + } + + Stmt visit(const LetStmt *op) override { + vector> lets; + int size = 0; + Stmt body; + do { + body = op->body; + size = get_scatter_gather_size(op->value.get()); + if (size != 0) { + break; + } + lets.emplace_back(op->name, op->value); + op = body.as(); + } while (op); + + if (size) { + internal_assert(op); + body = visit_gather_let_stmt(op, size); + } else { + internal_assert(op == nullptr); + body = mutate(body); + } + + for (auto it = lets.rbegin(); it != lets.rend(); it++) { + body = LetStmt::make(it->first, it->second, body); + } + + return body; + } + + Stmt visit(const Provide *op) override { + int size = get_scatter_gather_size(op); + + if (size == 0) { + return IRMutator::visit(op); + } + + // The LHS should contain at least one scatter op, or our scatters + // all go to the same place. Is it worth asserting this? It + // could be a bug, or it could be some sort of degenerate base case. 
+ + // Fork the args and the RHS into their various versions + vector provides; + vector names; + vector exprs; + for (extractor.idx = 0; extractor.idx < size; extractor.idx++) { + vector args = op->args; + for (Expr &a : args) { + string name = unique_name('t'); + exprs.push_back(extractor.mutate(a)); + names.push_back(name); + a = Variable::make(a.type(), name); + } + vector values = op->values; + for (Expr &v : values) { + v = extractor.mutate(v); + string name = unique_name('t'); + exprs.push_back(extractor.mutate(v)); + names.push_back(name); + v = Variable::make(v.type(), name); + } + provides.push_back(Provide::make(op->name, values, args)); + } + + Stmt s = Block::make(provides); + // We just duplicated all the non-scatter/gather stuff too, + // so do joint CSE on the exprs + Expr bundle = Call::make(Int(32), Call::bundle, exprs, Call::PureIntrinsic); + bundle = common_subexpression_elimination(bundle); + + vector> lets; + while (const Let *let = bundle.as()) { + lets.emplace_back(let->name, let->value); + bundle = let->body; + } + const Call *c = bundle.as(); + internal_assert(c && c->is_intrinsic(Call::bundle)); + for (size_t i = 0; i < exprs.size(); i++) { + if (is_pure(c->args[i])) { + // names[i] is only used once, so if the value is pure + // it should be substituted in + s = substitute(names[i], c->args[i], s); + } else { + lets.emplace_back(names[i], c->args[i]); + } + } + + for (auto it = lets.rbegin(); it != lets.rend(); it++) { + s = LetStmt::make(it->first, it->second, s); + } + + return s; + } +}; + } // namespace -Stmt split_tuples(const Stmt &s, const map &env) { - return SplitTuples(env).mutate(s); +Stmt split_tuples(const Stmt &stmt, const map &env) { + Stmt s = SplitTuples(env).mutate(stmt); + s = SplitScatterGather().mutate(s); + return s; } } // namespace Internal diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index f0a0fc28238e..56dca979b9e5 100644 --- a/test/correctness/CMakeLists.txt +++ 
b/test/correctness/CMakeLists.txt @@ -216,6 +216,7 @@ tests(GROUPS correctness multipass_constraints.cpp multiple_outputs.cpp multiple_outputs_extern.cpp + multiple_scatter.cpp mux.cpp named_updates.cpp nested_shiftinwards.cpp diff --git a/test/correctness/multiple_scatter.cpp b/test/correctness/multiple_scatter.cpp new file mode 100644 index 000000000000..e05ae75905a9 --- /dev/null +++ b/test/correctness/multiple_scatter.cpp @@ -0,0 +1,250 @@ +#include "Halide.h" + +using namespace Halide; + +using std::vector; + +int main(int argc, char **argv) { + // Implement a sorting network using update definitions that write to multiple outputs + + // The links in the sorting network. Sorts 8 things. + int network_[19][2] = + {{0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + {0, 2}, + {1, 3}, + {4, 6}, + {5, 7}, + {1, 2}, + {5, 6}, + {0, 4}, + {3, 7}, + {1, 5}, + {2, 6}, + {1, 4}, + {3, 6}, + {2, 4}, + {3, 5}, + {3, 4}}; + + Buffer network(&network_[0][0], 2, 19); + + Buffer input(128, 8); + + input.fill(std::mt19937{0}); + + Func sorted1; + Var x, y; + + // Run the sorting network with an RDom over the links + sorted1(x, y) = input(x, y); + RDom r(0, network.dim(1).extent()); + + // We know that the network we'll use caps out at 7, but the + // compiler doesn't know that because it's coming from an input + // buffer, so use unsafe_promise_clamped. 
+ Expr min_idx = unsafe_promise_clamped(network(0, r), 0, 7); + Expr max_idx = unsafe_promise_clamped(network(1, r), 0, 7); + + sorted1(x, scatter(min_idx, max_idx)) = + gather(min(sorted1(x, min_idx), sorted1(x, max_idx)), + max(sorted1(x, min_idx), sorted1(x, max_idx))); + + sorted1.vectorize(x, 8).update().vectorize(x, 8); + + Buffer output1(128, 8), output2(128, 8); + sorted1.realize(output1); + + // Run the sorting network fully unrolled as a single big multi-scatter + Func sorted2; + sorted2(x, y) = input(x, y); + + vector lhs, rhs; + for (int i = 0; i < 8; i++) { + lhs.emplace_back(i); + rhs.emplace_back(sorted2(x, i)); + } + + for (int l = 0; l < network.dim(1).extent(); l++) { + int min_idx = network(0, l); + int max_idx = network(1, l); + Expr tmp = rhs[min_idx]; + // We're going to be asking a lot of CSE + rhs[min_idx] = min(rhs[min_idx], rhs[max_idx]); + rhs[max_idx] = max(tmp, rhs[max_idx]); + } + + sorted2(x, scatter(lhs)) = gather(rhs); + sorted2.vectorize(x, 8).update().vectorize(x, 8); + + sorted2.realize(output2); + + for (int i = 0; i < output1.dim(0).extent(); i++) { + vector correct(output1.dim(1).extent()); + for (int j = 0; j < output1.dim(1).extent(); j++) { + correct[j] = input(i, j); + } + std::sort(correct.begin(), correct.end()); + for (int j = 0; j < output1.dim(1).extent(); j++) { + if (output1(i, j) != correct[j]) { + printf("output1(%d, %d) = %d instead of %d\n", i, j, output1(i, j), correct[j]); + return -1; + } + if (output2(i, j) != correct[j]) { + printf("output2(%d, %d) = %d instead of %d\n", i, j, output2(i, j), correct[j]); + return -1; + } + } + } + + { + // An update definitions that rotates a square region in-place. 
+ + const int sz = 17; + Buffer input(sz, sz); + std::mt19937 rng; + input.fill([&](int x, int y) { return (uint8_t)(rng() & 0xff); }); + + Func rot; + rot(x, y) = input(x, y); + + RDom r(0, (sz + 1) / 2, 0, sz / 2); + + vector src_x{r.x, sz - 1 - r.y, sz - 1 - r.x, r.y}; + vector src_y{r.y, r.x, sz - 1 - r.y, sz - 1 - r.x}; + vector dst_x = src_x, dst_y = src_y; + + std::rotate(dst_x.begin(), dst_x.begin() + 1, dst_x.end()); + std::rotate(dst_y.begin(), dst_y.begin() + 1, dst_y.end()); + + rot(scatter(dst_x), scatter(dst_y)) = + rot(gather(src_x), gather(src_y)); + + Buffer output = rot.realize(sz, sz); + + for (int y = 0; y < sz; y++) { + for (int x = 0; x < sz; x++) { + int correct = input(y, sz - 1 - x); + if (output(x, y) != correct) { + printf("output(%d, %d) = %d instead of %d\n", + x, y, output(x, y), correct); + return -1; + } + } + } + } + + { + // Atomic complex multiplication modulo 2^256 where the complex numbers + // are a dimension of the Func rather than a tuple + + Buffer input(2, 100); + std::mt19937 rng; + input.fill([&](int x, int y) { return (uint8_t)(rng() & 0xff); }); + + Func prod; + Var x; + RDom r(0, input.dim(1).extent()); + prod(x) = cast(mux(x, {1, 0})); + prod(scatter(0, 1)) = + gather(prod(0) * input(0, r) - prod(1) * input(1, r), + prod(0) * input(1, r) + prod(1) * input(0, r)); + + // TODO: We don't currently recognize this as an + // associative update, so for now we force it by passing + // 'true' to atomic(). 
+ prod.update().atomic(true).parallel(r); + + Buffer result = prod.realize(2); + + uint8_t correct_re = 1, correct_im = 0; + for (int i = 0; i < input.dim(1).extent(); i++) { + int new_re = correct_re * input(0, i) - correct_im * input(1, i); + int new_im = correct_re * input(1, i) + correct_im * input(0, i); + correct_re = new_re; + correct_im = new_im; + } + + if (correct_re != result(0) || correct_im != result(1)) { + printf("Complex multiplication reduction produced wrong result: \n" + "Got %d + %di instead of %d + %di\n", + result(0), result(1), correct_re, correct_im); + } + } + + { + // Lexicographic bubble sort on tuples + Func f; + Var x, y; + + f(x) = {13 - (x % 10), cast(x * 17)}; + + RDom r(0, 99, 0, 99); + r.where(r.x < 99 - r.y); + + Expr should_swap = (f(r.x)[0] > f(r.x + 1)[0] || + (f(r.x)[0] == f(r.x + 1)[0] && + f(r.x)[1] > f(r.x + 1)[1])); + r.where(should_swap); + + // Swap elements that satisfy the RDom predicate + f(scatter(r.x, r.x + 1)) = f(gather(r.x + 1, r.x)); + + Buffer out_0(100); + Buffer out_1(100); + f.realize({out_0, out_1}); + + for (int i = 0; i < 99; i++) { + bool check = (out_0(i) < out_0(i + 1) || + (out_0(i) == out_0(i + 1) && out_1(i) < out_1(i + 1))); + if (!check) { + printf("Sort result is not correctly ordered at elements %d, %d:\n" + "(%d, %d) vs (%d, %d)\n", + i, i + 1, out_0(i), out_1(i), out_0(i + 1), out_1(i + 1)); + return -1; + } + } + } + + { + // A scatter can exist without a gather if you're just broadcasting + Func f; + Var x; + f(x) = 0; + f(scatter(0, 1, 2, 3)) = 5; + + Buffer out = f.realize(5); + for (int i = 0; i < 5; i++) { + int correct = i < 4 ? 5 : 0; + if (out(i) != correct) { + printf("out(%d) = %d instead of %d\n", i, out(i), correct); + return -1; + } + } + } + + { + // A gather can exist without a scatter, but it's sort of + // silly because last element wins. It's not outright + // disallowed because it may be a degenerate case of some + // generic code. 
+ Func f; + Var x; + f(x) = 0; + f(3) = gather(1, 9); + + Buffer out = f.realize(5); + for (int i = 0; i < 5; i++) { + int correct = i == 3 ? 9 : 0; + if (out(i) != correct) { + printf("out(%d) = %d instead of %d\n", i, out(i), correct); + return -1; + } + } + } + + printf("Success!\n"); + return 0; +}