From 458efbd7b24f1e6e48792cc4c6f01213438b4552 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sat, 12 Dec 2020 12:09:18 -0800 Subject: [PATCH 01/16] Prototype of multiple scattering update definitions --- src/Bounds.cpp | 4 ++ src/IR.cpp | 1 + src/IR.h | 1 + src/SplitTuples.cpp | 122 +++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 126 insertions(+), 2 deletions(-) diff --git a/src/Bounds.cpp b/src/Bounds.cpp index ba1a51d3c23b..fdc887e86e21 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1423,6 +1423,10 @@ class Bounds : public IRVisitor { } else if (op->is_intrinsic(Call::memoize_expr)) { internal_assert(!op->args.empty()); op->args[0].accept(this); + } else if (op->is_intrinsic(Call::tuple)) { + // A tuple could evaluate to any one of the args. The base + // class visitor is fine as it takes a union. + IRVisitor::visit(op); } else if (op->call_type == Call::Halide) { bounds_of_func(op->name, op->value_index, op->type); } else { diff --git a/src/IR.cpp b/src/IR.cpp index a9b5ffb6bf37..eb773e1b520b 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -634,6 +634,7 @@ const char *const intrinsic_op_names[] = { "sorted_avg", "strict_float", "stringify", + "tuple", "undef", "unsafe_promise_clamped", }; diff --git a/src/IR.h b/src/IR.h index 85ff0e7775d6..1706a5eff637 100644 --- a/src/IR.h +++ b/src/IR.h @@ -546,6 +546,7 @@ struct Call : public ExprNode { sorted_avg, // Compute (arg[0] + arg[1]) / 2, assuming arg[0] < arg[1]. strict_float, stringify, + tuple, undef, unsafe_promise_clamped, IntrinsicOpCount // Sentinel: keep last. 
diff --git a/src/SplitTuples.cpp b/src/SplitTuples.cpp index dd56a415cf29..9b949a312878 100644 --- a/src/SplitTuples.cpp +++ b/src/SplitTuples.cpp @@ -1,11 +1,13 @@ #include "SplitTuples.h" #include "Bounds.h" +#include "CSE.h" #include "ExprUsesVar.h" #include "Function.h" #include "IRMutator.h" #include "IROperator.h" #include "Simplify.h" +#include "Substitute.h" namespace Halide { namespace Internal { @@ -335,10 +337,126 @@ class SplitTuples : public IRMutator { } }; +class SplitTupleExprs : public IRMutator { + using IRMutator::visit; + + // TODO: worry about LetStmts that have tuple intrinsics in the RHS + + Stmt visit(const Provide *op) override { + class GetTupleSize : public IRVisitor { + bool permitted = true; + using IRVisitor::visit; + void visit(const Call *op) override { + if (op->is_intrinsic(Call::tuple)) { + user_assert(permitted) + << "Can't nest an expression tuple inside another in definition of " + << op->name << "\n"; + if (result == 0) { + result = (int)op->args.size(); + } else { + user_assert((int)op->args.size() == result) + << "Expression tuples of mismatched sizes used in definition of " + << op->name << ": " << result << " vs " << op->args.size(); + } + // No nesting tuples + permitted = false; + IRVisitor::visit(op); + permitted = true; + } else { + IRVisitor::visit(op); + } + } + + public: + int result = 0; + } get_tuple_size; + + op->accept(&get_tuple_size); + int size = get_tuple_size.result; + + if (size == 0) { + return IRMutator::visit(op); + } + + // The LHS should contain at least one tuple, or our scatters + // all go to the same place. Is it worth asserting this? It + // could be a bug, or it could be some sort of degenerate base case. 
+ + // Fork the args and the RHS into their various versions + class ExtractTupleElement : public IRMutator { + using IRMutator::visit; + Expr visit(const Call *op) override { + if (op->is_intrinsic(Call::tuple)) { + // No need to recursively mutate because we've + // already asserted that these aren't nested. + internal_assert(idx < (int)op->args.size()); + return op->args[idx]; + } else { + return IRMutator::visit(op); + } + } + + public: + int idx; + } extractor; + + vector provides; + vector names; + vector rhs_values; + for (extractor.idx = 0; extractor.idx < size; extractor.idx++) { + vector args = op->args; + for (Expr &a : args) { + a = extractor.mutate(a); + } + vector values = op->values; + for (Expr &v : values) { + v = extractor.mutate(v); + string name = unique_name('t'); + rhs_values.push_back(extractor.mutate(v)); + names.push_back(name); + v = Variable::make(v.type(), name); + } + provides.push_back(Provide::make(op->name, values, args)); + } + + Stmt s = Block::make(provides); + + // We just duplicated all the non-tuple stuff on the RHS too, + // so do joint CSE on the rhs_values + Expr bundle = Call::make(Int(32), Call::bundle, rhs_values, Call::PureIntrinsic); + bundle = common_subexpression_elimination(bundle); + + vector> lets; + while (const Let *let = bundle.as()) { + lets.emplace_back(let->name, let->value); + bundle = let->body; + } + const Call *c = bundle.as(); + internal_assert(c && c->is_intrinsic(Call::bundle)); + for (size_t i = 0; i < rhs_values.size(); i++) { + if (is_pure(c->args[i])) { + // names[i] is only used once, so if the value is pure + // it should be substituted in + s = substitute(names[i], c->args[i], s); + } else { + lets.emplace_back(names[i], c->args[i]); + } + } + + for (auto it = lets.rbegin(); it != lets.rend(); it++) { + s = LetStmt::make(it->first, it->second, s); + } + + return s; + } +}; + } // namespace -Stmt split_tuples(const Stmt &s, const map &env) { - return SplitTuples(env).mutate(s); +Stmt 
split_tuples(const Stmt &stmt, const map &env) { + Stmt s = SplitTuples(env).mutate(stmt); + s = SplitTupleExprs().mutate(s); + return s; } } // namespace Internal From 5f6eb57bcd6fa9f5914cb631694d9740874f7f35 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sat, 12 Dec 2020 12:15:58 -0800 Subject: [PATCH 02/16] Add test --- test/correctness/CMakeLists.txt | 1 + test/correctness/multiple_scatter.cpp | 185 ++++++++++++++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 test/correctness/multiple_scatter.cpp diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index e1ba9d8e4eb3..e8e5401d5325 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -214,6 +214,7 @@ tests(GROUPS correctness multipass_constraints.cpp multiple_outputs.cpp multiple_outputs_extern.cpp + multiple_scatter.cpp mux.cpp named_updates.cpp nested_shiftinwards.cpp diff --git a/test/correctness/multiple_scatter.cpp b/test/correctness/multiple_scatter.cpp new file mode 100644 index 000000000000..4e6c706fb199 --- /dev/null +++ b/test/correctness/multiple_scatter.cpp @@ -0,0 +1,185 @@ +#include "Halide.h" + +using namespace Halide; + +using std::vector; + +Expr make_expr_tuple(const vector &args) { + return Halide::Internal::Call::make(args[0].type(), Halide::Internal::Call::tuple, args, Halide::Internal::Call::PureIntrinsic); +} + +int main(int argc, char **argv) { + // Implement a sorting network using update definitions that write to multiple outputs + + // The links in the sorting network. Sorts 8 things. 
+ int network_[19][2] = + {{0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + {0, 2}, + {1, 3}, + {4, 6}, + {5, 7}, + {1, 2}, + {5, 6}, + {0, 4}, + {3, 7}, + {1, 5}, + {2, 6}, + {1, 4}, + {3, 6}, + {2, 4}, + {3, 5}, + {3, 4}}; + + Buffer network(&network_[0][0], 2, 19); + + Buffer input(128, 8); + + input.fill(std::mt19937{0}); + + Func sorted1; + Var x, y; + + // Run the sorting network with an RDom over the links + sorted1(x, y) = input(x, y); + RDom r(0, network.dim(1).extent()); + Expr min_idx = unsafe_promise_clamped(network(0, r), 0, 7); + Expr max_idx = unsafe_promise_clamped(network(1, r), 0, 7); + Expr dst = make_expr_tuple({min_idx, max_idx}); + sorted1(x, dst) = + make_expr_tuple({min(sorted1(x, min_idx), sorted1(x, max_idx)), + max(sorted1(x, min_idx), sorted1(x, max_idx))}); + + sorted1.vectorize(x, 8).update().vectorize(x, 8); + + Buffer output1(128, 8), output2(128, 8); + sorted1.realize(output1); + + // Run the sorting network fully unrolled as a single big multi-scatter + Func sorted2; + sorted2(x, y) = input(x, y); + + vector lhs, rhs; + for (int i = 0; i < 8; i++) { + lhs.emplace_back(i); + rhs.emplace_back(sorted2(x, i)); + } + + for (int l = 0; l < network.dim(1).extent(); l++) { + int min_idx = network(0, l); + int max_idx = network(1, l); + Expr tmp = rhs[min_idx]; + // We're going to be asking a lot of CSE + rhs[min_idx] = min(rhs[min_idx], rhs[max_idx]); + rhs[max_idx] = max(tmp, rhs[max_idx]); + } + + sorted2(x, make_expr_tuple(lhs)) = make_expr_tuple(rhs); + sorted2.vectorize(x, 8).update().vectorize(x, 8); + + sorted2.realize(output2); + + for (int i = 0; i < output1.dim(0).extent(); i++) { + vector correct(output1.dim(1).extent()); + for (int j = 0; j < output1.dim(1).extent(); j++) { + correct[j] = input(i, j); + } + std::sort(correct.begin(), correct.end()); + for (int j = 0; j < output1.dim(1).extent(); j++) { + if (output1(i, j) != correct[j]) { + printf("output1(%d, %d) = %d instead of %d\n", i, j, output1(i, j), correct[j]); + return -1; + } 
+ if (output2(i, j) != correct[j]) { + printf("output2(%d, %d) = %d instead of %d\n", i, j, output2(i, j), correct[j]); + return -1; + } + } + } + + { + // An update definition that rotates a square region in-place. + + const int sz = 17; + Buffer input(sz, sz); + std::mt19937 rng; + input.fill([&](int x, int y) { return (uint8_t)(rng() & 0xff); }); + + Func rot; + rot(x, y) = input(x, y); + + RDom r(0, (sz + 1) / 2, 0, sz / 2); + + vector in{rot(r.x, r.y), + rot(sz - 1 - r.y, r.x), + rot(sz - 1 - r.x, sz - 1 - r.y), + rot(r.y, sz - 1 - r.x)}; + + vector src_x{r.x, sz - 1 - r.y, sz - 1 - r.x, r.y}; + vector src_y{r.y, r.x, sz - 1 - r.y, sz - 1 - r.x}; + vector dst_x = src_x, dst_y = src_y; + + std::rotate(dst_x.begin(), dst_x.begin() + 1, dst_x.end()); + std::rotate(dst_y.begin(), dst_y.begin() + 1, dst_y.end()); + + rot(make_expr_tuple(dst_x), make_expr_tuple(dst_y)) = + rot(make_expr_tuple(src_x), make_expr_tuple(src_y)); + + Buffer output = rot.realize(sz, sz); + + for (int y = 0; y < sz; y++) { + for (int x = 0; x < sz; x++) { + int correct = input(y, sz - 1 - x); + if (output(x, y) != correct) { + printf("output(%d, %d) = %d instead of %d\n", + x, y, output(x, y), correct); + return -1; + } + } + } + } + + { + // Atomic complex multiplication modulo 2^8 where the complex numbers + // are a dimension of the Func rather than a tuple + + Buffer input(2, 100); + std::mt19937 rng; + input.fill([&](int x, int y) { return (uint8_t)(rng() & 0xff); }); + + Func prod; + Var x; + RDom r(0, input.dim(1).extent()); + Expr lhs = make_expr_tuple({0, 1}); + prod(x) = cast(mux(x, {1, 0})); + prod(lhs) = make_expr_tuple( + {prod(0) * input(0, r) - prod(1) * input(1, r), + prod(0) * input(1, r) + prod(1) * input(0, r)}); + + // TODO: We don't currently recognize this as an + // associative update, so for now we force it by passing + // 'true' to atomic().
+ prod.update().atomic(true).parallel(r); + + Buffer result = prod.realize(2); + + uint8_t correct_re = 1, correct_im = 0; + for (int i = 0; i < input.dim(1).extent(); i++) { + int new_re = correct_re * input(0, i) - correct_im * input(1, i); + int new_im = correct_re * input(1, i) + correct_im * input(0, i); + correct_re = new_re; + correct_im = new_im; + } + + if (correct_re != result(0) || correct_im != result(1)) { + printf("Complex multiplication reduction produced wrong result: \n" + "Got %d + %di instead of %d + %di\n", + result(0), result(1), correct_re, correct_im); + } + } + + printf("Success!\n"); + return 0; +} From 798656f99a493184e26ad5bc43d6a31e21b88935 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sat, 12 Dec 2020 14:52:27 -0800 Subject: [PATCH 03/16] Handle tuples of tuples --- src/SplitTuples.cpp | 116 ++++++++++++++++---------- test/correctness/multiple_scatter.cpp | 37 ++++++++ 2 files changed, 110 insertions(+), 43 deletions(-) diff --git a/src/SplitTuples.cpp b/src/SplitTuples.cpp index 9b949a312878..2dc48dfafd60 100644 --- a/src/SplitTuples.cpp +++ b/src/SplitTuples.cpp @@ -340,37 +340,83 @@ class SplitTuples : public IRMutator { class SplitTupleExprs : public IRMutator { using IRMutator::visit; - // TODO: worry about LetStmts that have tuple intrinsics in the RHS - - Stmt visit(const Provide *op) override { - class GetTupleSize : public IRVisitor { - bool permitted = true; - using IRVisitor::visit; - void visit(const Call *op) override { - if (op->is_intrinsic(Call::tuple)) { - user_assert(permitted) - << "Can't nest an expression tuple inside another in definition of " - << op->name << "\n"; - if (result == 0) { - result = (int)op->args.size(); - } else { - user_assert((int)op->args.size() == result) - << "Expression tuples of mismatched sizes used in definition of " - << op->name << ": " << result << " vs " << op->args.size(); - } - // No nesting tuples - permitted = false; - IRVisitor::visit(op); - permitted = true; + class 
GetTupleSize : public IRVisitor { + bool permitted = true; + using IRVisitor::visit; + void visit(const Call *op) override { + if (op->is_intrinsic(Call::tuple)) { + user_assert(permitted) + << "Can't nest an expression tuple inside another in definition of " + << op->name << "\n"; + if (result == 0) { + result = (int)op->args.size(); } else { - IRVisitor::visit(op); + user_assert((int)op->args.size() == result) + << "Expression tuples of mismatched sizes used in definition of " + << op->name << ": " << result << " vs " << op->args.size(); } + // No nesting tuples + permitted = false; + IRVisitor::visit(op); + permitted = true; + } else { + IRVisitor::visit(op); } + } + + public: + int result = 0; + }; + + class ExtractTupleElement : public IRMutator { + using IRMutator::visit; + Expr visit(const Call *op) override { + if (op->is_intrinsic(Call::tuple)) { + // No need to recursively mutate because we've + // already asserted that these aren't nested. + internal_assert(idx < (int)op->args.size()); + return op->args[idx]; + } else { + return IRMutator::visit(op); + } + } + + public: + int idx; + }; + + Stmt visit(const LetStmt *op) override { + GetTupleSize get_tuple_size; + op->value.accept(&get_tuple_size); + if (get_tuple_size.result == 0) { + return IRMutator::visit(op); + } - public: - int result = 0; - } get_tuple_size; + // Split this variable into the tuple components + ExtractTupleElement extractor; + vector> lets; + vector vars; + for (extractor.idx = 0; extractor.idx < get_tuple_size.result; extractor.idx++) { + string name = unique_name(op->name + "." 
+ std::to_string(extractor.idx)); + lets.emplace_back(name, extractor.mutate(op->value)); + vars.push_back(Variable::make(op->value.type(), name)); + } + + Stmt body = op->body; + Expr tuple_replacement = Call::make(op->value.type(), Call::tuple, vars, Call::PureIntrinsic); + body = substitute(op->name, tuple_replacement, body); + body = mutate(body); + + for (auto it = lets.rbegin(); it != lets.rend(); it++) { + body = LetStmt::make(it->first, it->second, body); + } + + return body; + } + + Stmt visit(const Provide *op) override { + GetTupleSize get_tuple_size; op->accept(&get_tuple_size); int size = get_tuple_size.result; @@ -378,28 +424,12 @@ class SplitTupleExprs : public IRMutator { return IRMutator::visit(op); } + ExtractTupleElement extractor; // The LHS should contain at least one tuple, or our scatters // all go to the same place. Is it worth asserting this? It // could be a bug, or it could be some sort of degenerate base case. // Fork the args and the RHS into their various versions - class ExtractTupleElement : public IRMutator { - using IRMutator::visit; - Expr visit(const Call *op) override { - if (op->is_intrinsic(Call::tuple)) { - // No need to recursively mutate because we've - // already asserted that these aren't nested. 
- internal_assert(idx < (int)op->args.size()); - return op->args[idx]; - } else { - return IRMutator::visit(op); - } - } - - public: - int idx; - } extractor; - vector provides; vector names; vector rhs_values; diff --git a/test/correctness/multiple_scatter.cpp b/test/correctness/multiple_scatter.cpp index 4e6c706fb199..5c939cb3e334 100644 --- a/test/correctness/multiple_scatter.cpp +++ b/test/correctness/multiple_scatter.cpp @@ -180,6 +180,43 @@ int main(int argc, char **argv) { } } + { + // Lexicographic bubble sort on tuples + Func f; + Var x, y; + + f(x) = {13 - (x % 10), cast(x * 17)}; + + RDom r(0, 99, 0, 99); + r.where(r.x < 99 - r.y); + + Expr should_swap = (f(r.x)[0] > f(r.x + 1)[0] || + (f(r.x)[0] == f(r.x + 1)[0] && + f(r.x)[1] > f(r.x + 1)[1])); + r.where(should_swap); + + // This update swaps adjacent pairs of elements whenever the + // second tuple component of the even-indexed element is odd. + + f(make_expr_tuple({r.x, r.x + 1})) = + f(make_expr_tuple({r.x + 1, r.x})); + + Buffer out_0(100); + Buffer out_1(100); + f.realize({out_0, out_1}); + + for (int i = 0; i < 99; i++) { + bool check = (out_0(i) < out_0(i + 1) || + (out_0(i) == out_0(i + 1) && out_1(i) < out_1(i + 1))); + if (!check) { + printf("Sort result is not correctly ordered at elements %d, %d:\n" + "(%d, %d) vs (%d, %d)\n", + i, i + 1, out_0(i), out_1(i), out_0(i + 1), out_1(i + 1)); + return -1; + } + } + } + printf("Success!\n"); return 0; } From 505f975401e0a6a1b9c862e016d2e5ea46cfcdc7 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sat, 12 Dec 2020 15:01:11 -0800 Subject: [PATCH 04/16] Better error messages --- src/SplitTuples.cpp | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/SplitTuples.cpp b/src/SplitTuples.cpp index 2dc48dfafd60..3979fa0b0690 100644 --- a/src/SplitTuples.cpp +++ b/src/SplitTuples.cpp @@ -340,6 +340,9 @@ class SplitTuples : public IRMutator { class SplitTupleExprs : public IRMutator { using IRMutator::visit; + // 
The enclosing producer node. Used for error messages. + const ProducerConsumer *producer = nullptr; + class GetTupleSize : public IRVisitor { bool permitted = true; using IRVisitor::visit; @@ -347,13 +350,13 @@ class SplitTupleExprs : public IRMutator { if (op->is_intrinsic(Call::tuple)) { user_assert(permitted) << "Can't nest an expression tuple inside another in definition of " - << op->name << "\n"; + << producer_name; if (result == 0) { result = (int)op->args.size(); } else { user_assert((int)op->args.size() == result) << "Expression tuples of mismatched sizes used in definition of " - << op->name << ": " << result << " vs " << op->args.size(); + << producer_name << ": " << result << " vs " << op->args.size(); } // No nesting tuples permitted = false; @@ -364,8 +367,17 @@ class SplitTupleExprs : public IRMutator { } } + // Just for error messages. The default value should not + // currently be possible to hit. + string producer_name = "(tuple expression not part of a Func definition)"; + public: int result = 0; + GetTupleSize(const ProducerConsumer *producer) { + if (producer) { + producer_name = producer->name; + } + } }; class ExtractTupleElement : public IRMutator { @@ -385,8 +397,13 @@ class SplitTupleExprs : public IRMutator { int idx; }; + Stmt visit(const ProducerConsumer *op) override { + ScopedValue old(producer, op->is_producer ? 
op : producer); + return IRMutator::visit(op); + } + Stmt visit(const LetStmt *op) override { - GetTupleSize get_tuple_size; + GetTupleSize get_tuple_size(producer); op->value.accept(&get_tuple_size); if (get_tuple_size.result == 0) { return IRMutator::visit(op); @@ -416,7 +433,7 @@ class SplitTupleExprs : public IRMutator { } Stmt visit(const Provide *op) override { - GetTupleSize get_tuple_size; + GetTupleSize get_tuple_size(producer); op->accept(&get_tuple_size); int size = get_tuple_size.result; From 0c24fce696aa77dddd294ef4c0ec4cce5e94972f Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sat, 12 Dec 2020 15:02:05 -0800 Subject: [PATCH 05/16] Fix comment --- test/correctness/multiple_scatter.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/correctness/multiple_scatter.cpp b/test/correctness/multiple_scatter.cpp index 5c939cb3e334..b38102c2a271 100644 --- a/test/correctness/multiple_scatter.cpp +++ b/test/correctness/multiple_scatter.cpp @@ -195,9 +195,7 @@ int main(int argc, char **argv) { f(r.x)[1] > f(r.x + 1)[1])); r.where(should_swap); - // This update swaps adjacent pairs of elements whenever the - // second tuple component of the even-indexed element is odd. 
- + // Swap elements that satisfy the RDom predicate f(make_expr_tuple({r.x, r.x + 1})) = f(make_expr_tuple({r.x + 1, r.x})); From 75c9373cadf3e23e78ea7edb75b40a3cf17badd4 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 13 Dec 2020 10:33:50 -0800 Subject: [PATCH 06/16] Try a new name on for size --- src/Bounds.cpp | 2 +- src/IR.cpp | 2 +- src/IR.h | 2 +- src/SplitTuples.cpp | 34 ++++++++--------- test/correctness/multiple_scatter.cpp | 53 +++++++++++++++++++-------- 5 files changed, 58 insertions(+), 35 deletions(-) diff --git a/src/Bounds.cpp b/src/Bounds.cpp index fdc887e86e21..c9fba2dd594b 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1423,7 +1423,7 @@ class Bounds : public IRVisitor { } else if (op->is_intrinsic(Call::memoize_expr)) { internal_assert(!op->args.empty()); op->args[0].accept(this); - } else if (op->is_intrinsic(Call::tuple)) { + } else if (op->is_intrinsic(Call::scatter_gather)) { // A tuple could evaluate to any one of the args. The base // class visitor is fine as it takes a union. IRVisitor::visit(op); diff --git a/src/IR.cpp b/src/IR.cpp index eb773e1b520b..2b05ba0891cd 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -626,6 +626,7 @@ const char *const intrinsic_op_names[] = { "require_mask", "return_second", "rewrite_buffer", + "scatter_gather", "select_mask", "shift_left", "shift_right", @@ -634,7 +635,6 @@ const char *const intrinsic_op_names[] = { "sorted_avg", "strict_float", "stringify", - "tuple", "undef", "unsafe_promise_clamped", }; diff --git a/src/IR.h b/src/IR.h index 1706a5eff637..a537262dfa5c 100644 --- a/src/IR.h +++ b/src/IR.h @@ -538,6 +538,7 @@ struct Call : public ExprNode { require_mask, return_second, rewrite_buffer, + scatter_gather, select_mask, shift_left, shift_right, @@ -546,7 +547,6 @@ struct Call : public ExprNode { sorted_avg, // Compute (arg[0] + arg[1]) / 2, assuming arg[0] < arg[1]. strict_float, stringify, - tuple, undef, unsafe_promise_clamped, IntrinsicOpCount // Sentinel: keep last. 
diff --git a/src/SplitTuples.cpp b/src/SplitTuples.cpp index 3979fa0b0690..0c3fef6a4f47 100644 --- a/src/SplitTuples.cpp +++ b/src/SplitTuples.cpp @@ -337,17 +337,17 @@ class SplitTuples : public IRMutator { } }; -class SplitTupleExprs : public IRMutator { +class SplitScatterGather : public IRMutator { using IRMutator::visit; // The enclosing producer node. Used for error messages. const ProducerConsumer *producer = nullptr; - class GetTupleSize : public IRVisitor { + class GetScatterGatherSize : public IRVisitor { bool permitted = true; using IRVisitor::visit; void visit(const Call *op) override { - if (op->is_intrinsic(Call::tuple)) { + if (op->is_intrinsic(Call::scatter_gather)) { user_assert(permitted) << "Can't nest an expression tuple inside another in definition of " << producer_name; @@ -373,17 +373,17 @@ class SplitTupleExprs : public IRMutator { public: int result = 0; - GetTupleSize(const ProducerConsumer *producer) { + GetScatterGatherSize(const ProducerConsumer *producer) { if (producer) { producer_name = producer->name; } } }; - class ExtractTupleElement : public IRMutator { + class ExtractScatterGatherElement : public IRMutator { using IRMutator::visit; Expr visit(const Call *op) override { - if (op->is_intrinsic(Call::tuple)) { + if (op->is_intrinsic(Call::scatter_gather)) { // No need to recursively mutate because we've // already asserted that these aren't nested. 
internal_assert(idx < (int)op->args.size()); @@ -403,25 +403,25 @@ class SplitTupleExprs : public IRMutator { } Stmt visit(const LetStmt *op) override { - GetTupleSize get_tuple_size(producer); - op->value.accept(&get_tuple_size); - if (get_tuple_size.result == 0) { + GetScatterGatherSize get_scatter_gather_size(producer); + op->value.accept(&get_scatter_gather_size); + if (get_scatter_gather_size.result == 0) { return IRMutator::visit(op); } // Split this variable into the tuple components - ExtractTupleElement extractor; + ExtractScatterGatherElement extractor; vector> lets; vector vars; - for (extractor.idx = 0; extractor.idx < get_tuple_size.result; extractor.idx++) { + for (extractor.idx = 0; extractor.idx < get_scatter_gather_size.result; extractor.idx++) { string name = unique_name(op->name + "." + std::to_string(extractor.idx)); lets.emplace_back(name, extractor.mutate(op->value)); vars.push_back(Variable::make(op->value.type(), name)); } Stmt body = op->body; - Expr tuple_replacement = Call::make(op->value.type(), Call::tuple, vars, Call::PureIntrinsic); + Expr tuple_replacement = Call::make(op->value.type(), Call::scatter_gather, vars, Call::PureIntrinsic); body = substitute(op->name, tuple_replacement, body); body = mutate(body); @@ -433,15 +433,15 @@ class SplitTupleExprs : public IRMutator { } Stmt visit(const Provide *op) override { - GetTupleSize get_tuple_size(producer); - op->accept(&get_tuple_size); - int size = get_tuple_size.result; + GetScatterGatherSize get_scatter_gather_size(producer); + op->accept(&get_scatter_gather_size); + int size = get_scatter_gather_size.result; if (size == 0) { return IRMutator::visit(op); } - ExtractTupleElement extractor; + ExtractScatterGatherElement extractor; // The LHS should contain at least one tuple, or our scatters // all go to the same place. Is it worth asserting this? It // could be a bug, or it could be some sort of degenerate base case. 
@@ -502,7 +502,7 @@ class SplitTupleExprs : public IRMutator { Stmt split_tuples(const Stmt &stmt, const map &env) { Stmt s = SplitTuples(env).mutate(stmt); - s = SplitTupleExprs().mutate(s); + s = SplitScatterGather().mutate(s); return s; } diff --git a/test/correctness/multiple_scatter.cpp b/test/correctness/multiple_scatter.cpp index b38102c2a271..840928cd5b9e 100644 --- a/test/correctness/multiple_scatter.cpp +++ b/test/correctness/multiple_scatter.cpp @@ -4,8 +4,29 @@ using namespace Halide; using std::vector; -Expr make_expr_tuple(const vector &args) { - return Halide::Internal::Call::make(args[0].type(), Halide::Internal::Call::tuple, args, Halide::Internal::Call::PureIntrinsic); +Expr make_scatter_gather(const vector &args) { + return Halide::Internal::Call::make(args[0].type(), + Halide::Internal::Call::scatter_gather, + args, + Halide::Internal::Call::PureIntrinsic); +} + +template +Expr scatter(Expr e, Args... args) { + return make_scatter_gather({e, args...}); +} + +template +Expr gather(Expr e, Args... args) { + return make_scatter_gather({e, args...}); +} + +Expr scatter(const vector &args) { + return make_scatter_gather(args); +} + +Expr gather(const vector &args) { + return make_scatter_gather(args); } int main(int argc, char **argv) { @@ -45,12 +66,16 @@ int main(int argc, char **argv) { // Run the sorting network with an RDom over the links sorted1(x, y) = input(x, y); RDom r(0, network.dim(1).extent()); + + // We know that the network we'll use caps out at 7, but the + // compiler doesn't know that because it's coming from an input + // buffer, so use unsafe_promise_clamped. 
Expr min_idx = unsafe_promise_clamped(network(0, r), 0, 7); Expr max_idx = unsafe_promise_clamped(network(1, r), 0, 7); - Expr dst = make_expr_tuple({min_idx, max_idx}); - sorted1(x, dst) = - make_expr_tuple({min(sorted1(x, min_idx), sorted1(x, max_idx)), - max(sorted1(x, min_idx), sorted1(x, max_idx))}); + + sorted1(x, scatter(min_idx, max_idx)) = + gather(min(sorted1(x, min_idx), sorted1(x, max_idx)), + max(sorted1(x, min_idx), sorted1(x, max_idx))); sorted1.vectorize(x, 8).update().vectorize(x, 8); @@ -76,7 +101,7 @@ int main(int argc, char **argv) { rhs[max_idx] = max(tmp, rhs[max_idx]); } - sorted2(x, make_expr_tuple(lhs)) = make_expr_tuple(rhs); + sorted2(x, scatter(lhs)) = gather(rhs); sorted2.vectorize(x, 8).update().vectorize(x, 8); sorted2.realize(output2); @@ -124,8 +149,8 @@ int main(int argc, char **argv) { std::rotate(dst_x.begin(), dst_x.begin() + 1, dst_x.end()); std::rotate(dst_y.begin(), dst_y.begin() + 1, dst_y.end()); - rot(make_expr_tuple(dst_x), make_expr_tuple(dst_y)) = - rot(make_expr_tuple(src_x), make_expr_tuple(src_y)); + rot(scatter(dst_x), scatter(dst_y)) = + rot(gather(src_x), gather(src_y)); Buffer output = rot.realize(sz, sz); @@ -152,11 +177,10 @@ int main(int argc, char **argv) { Func prod; Var x; RDom r(0, input.dim(1).extent()); - Expr lhs = make_expr_tuple({0, 1}); prod(x) = cast(mux(x, {1, 0})); - prod(lhs) = make_expr_tuple( - {prod(0) * input(0, r) - prod(1) * input(1, r), - prod(0) * input(1, r) + prod(1) * input(0, r)}); + prod(scatter(0, 1)) = + gather(prod(0) * input(0, r) - prod(1) * input(1, r), + prod(0) * input(1, r) + prod(1) * input(0, r)); // TODO: We don't currently recognize this as an // associative update, so for now we force it by passing @@ -196,8 +220,7 @@ int main(int argc, char **argv) { r.where(should_swap); // Swap elements that satisfy the RDom predicate - f(make_expr_tuple({r.x, r.x + 1})) = - f(make_expr_tuple({r.x + 1, r.x})); + f(scatter(r.x, r.x + 1)) = f(gather(r.x + 1, r.x)); Buffer out_0(100); 
Buffer out_1(100); From 2926dc806efec1058e013666acbbeb2e6a80e216 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 20 Dec 2020 10:03:33 -0800 Subject: [PATCH 07/16] Fix spelling mistake --- src/CSE.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CSE.cpp b/src/CSE.cpp index 552cef94d727..d57d1c18cfaa 100644 --- a/src/CSE.cpp +++ b/src/CSE.cpp @@ -314,7 +314,7 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) { // Wrap the final expr in the lets. for (size_t i = lets.size(); i > 0; i--) { Expr value = lets[i - 1].second; - // Drop this variable as an acceptible replacement for this expr. + // Drop this variable as an acceptable replacement for this expr. replacer.erase(value); // Use containing lets in the value. value = replacer.mutate(lets[i - 1].second); From 66f6f3548b14eb85d8034a32bff9cd151f318560 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 20 Dec 2020 10:26:20 -0800 Subject: [PATCH 08/16] Move scatter/gather helpers to IROperator.h --- src/IROperator.cpp | 19 +++++++ src/IROperator.h | 71 +++++++++++++++++++++++++++ test/correctness/multiple_scatter.cpp | 25 ---------- 3 files changed, 90 insertions(+), 25 deletions(-) diff --git a/src/IROperator.cpp b/src/IROperator.cpp index 2283a1dbf7e6..2667875242a4 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -2360,4 +2360,23 @@ Expr undef(Type t) { Internal::Call::PureIntrinsic); } +namespace { +Expr make_scatter_gather(const std::vector &args) { + // There's currently no difference in the IR between a gather and + // a scatter. They're distinct just to make code more readable. 
+ return Halide::Internal::Call::make(args[0].type(), + Halide::Internal::Call::scatter_gather, + args, + Halide::Internal::Call::PureIntrinsic); +} +} // namespace + +Expr scatter(const std::vector &args) { + return make_scatter_gather(args); +} + +Expr gather(const std::vector &args) { + return make_scatter_gather(args); +} + } // namespace Halide diff --git a/src/IROperator.h b/src/IROperator.h index fc2dffffcc7a..24cf107243fc 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -1398,6 +1398,77 @@ namespace Internal { Expr promise_clamped(const Expr &value, const Expr &min, const Expr &max); } // namespace Internal +/** Scatter and gather are used for update definitions which must store + * multiple values to distinct locations at the same time. The + * multiple expressions on the right-hand-side are bundled together + * into a "gather", which must match a "scatter" with the same number + * of arguments on the left-hand-side. For example, to store the + * values 1 and 2 to the locations (x, y, 3) and (x, y, 4), + * respectively: + * +\code +f(x, y, scatter(3, 4)) = gather(1, 2); +\endcode + * + * The result of gather or scatter can be treated as an + * expression. Any containing operations on it can be assumed to + * distribute over the elements. If two gather expressions are + * combined with an arithmetic operator (e.g. added), they combine + * element-wise. The following example stores the values 2 * x, 2 * y, + * and 2 * c to the locations (x + 1, y, c), (x, y + 3, c), and (x, y, + * c + 2) respectively: + * +\code +f(x + scatter(1, 0, 0), y + scatter(0, 3, 0), c + scatter(0, 0, 2)) = 2 * gather(x, y, c); +\endcode +* +* Gathers are most useful for algorithms which require in-place +* swapping or permutation of multiple elements, or other kinds of +* in-place mutations that require loading multiple inputs, doing some +* operations to them jointly, then storing them again.
The following +* update definition swaps the values of f at locations 3 and 5 if an +* input parameter p is true: +* +\code +f(scatter(3, 5)) = f(select(p, gather(5, 3), gather(3, 5))); +\endcode +* +* For more examples of the use of scatter and gather, see +* test/correctness/multiple_scatter.cpp +* +* It is not currently possible to use scatter and gather to write an +* update definition in which the *number* of values loaded or stored +* varies, as the size of the scatter/gather packet must be fixed at +* compile-time. A workaround is to make the unwanted extra operations +* a redundant copy of the last operation, which will be +* dead-code-eliminated by the compiler. For example, the following +* update definition swaps the values at locations 3 and 5 when the +* parameter p is true, and rotates the values at locations 1, 2, and 3 +* when it is false. The load from 3 and store to 5 will be redundantly +* repeated: +* +\code +f(select(p, scatter(3, 5, 5), scatter(1, 2, 3))) = f(select(p, gather(5, 3, 3), gather(2, 3, 1))); +\endcode +* +* Note that in the p == true case, we redundantly load from 3 and write +* to 5 twice. +*/ +//@{ +Expr scatter(const std::vector &args); +Expr gather(const std::vector &args); + +template +Expr scatter(Expr e, Args... args) { + return scatter({e, args...}); +} + +template +Expr gather(Expr e, Args... args) { + return gather({e, args...}); +} +// @} + } // namespace Halide #endif diff --git a/test/correctness/multiple_scatter.cpp b/test/correctness/multiple_scatter.cpp index 840928cd5b9e..bd1e8a694469 100644 --- a/test/correctness/multiple_scatter.cpp +++ b/test/correctness/multiple_scatter.cpp @@ -4,31 +4,6 @@ using namespace Halide; using std::vector; -Expr make_scatter_gather(const vector &args) { - return Halide::Internal::Call::make(args[0].type(), - Halide::Internal::Call::scatter_gather, - args, - Halide::Internal::Call::PureIntrinsic); -} - -template -Expr scatter(Expr e, Args...
args) { - return make_scatter_gather({e, args...}); -} - -template -Expr gather(Expr e, Args... args) { - return make_scatter_gather({e, args...}); -} - -Expr scatter(const vector &args) { - return make_scatter_gather(args); -} - -Expr gather(const vector &args) { - return make_scatter_gather(args); -} - int main(int argc, char **argv) { // Implement a sorting network using update definitions that write to multiple outputs From b6883f57ae4a4d50f26da49ce0572f2a00ea9ce4 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 20 Dec 2020 10:26:31 -0800 Subject: [PATCH 09/16] Do joint CSE on the LHS too when splitting Tuples --- src/SplitTuples.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/SplitTuples.cpp b/src/SplitTuples.cpp index 0c3fef6a4f47..5608ee5932e9 100644 --- a/src/SplitTuples.cpp +++ b/src/SplitTuples.cpp @@ -449,17 +449,20 @@ class SplitScatterGather : public IRMutator { // Fork the args and the RHS into their various versions vector provides; vector names; - vector rhs_values; + vector exprs; for (extractor.idx = 0; extractor.idx < size; extractor.idx++) { vector args = op->args; for (Expr &a : args) { - a = extractor.mutate(a); + string name = unique_name('t'); + exprs.push_back(extractor.mutate(a)); + names.push_back(name); + a = Variable::make(a.type(), name); } vector values = op->values; for (Expr &v : values) { v = extractor.mutate(v); string name = unique_name('t'); - rhs_values.push_back(extractor.mutate(v)); + exprs.push_back(extractor.mutate(v)); names.push_back(name); v = Variable::make(v.type(), name); } @@ -467,10 +470,9 @@ class SplitScatterGather : public IRMutator { } Stmt s = Block::make(provides); - - // We just duplicated all the non-tuple stuff on the RHS too, - // so do joint CSE on the rhs_values - Expr bundle = Call::make(Int(32), Call::bundle, rhs_values, Call::PureIntrinsic); + // We just duplicated all the non-tuple stuff too, + // so do joint CSE on the exprs + Expr bundle = 
Call::make(Int(32), Call::bundle, exprs, Call::PureIntrinsic); bundle = common_subexpression_elimination(bundle); vector> lets; @@ -480,7 +482,7 @@ class SplitScatterGather : public IRMutator { } const Call *c = bundle.as(); internal_assert(c && c->is_intrinsic(Call::bundle)); - for (size_t i = 0; i < rhs_values.size(); i++) { + for (size_t i = 0; i < exprs.size(); i++) { if (is_pure(c->args[i])) { // names[i] is only used once, so if the value is pure // it should be substituted in From 5203a9a047151270ff9b61a4df874e8874a59c79 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 20 Dec 2020 10:27:21 -0800 Subject: [PATCH 10/16] Add useful simplifications for the code generated by scatter/gather pairs --- src/Simplify_And.cpp | 1 + src/Simplify_Stmts.cpp | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/Simplify_And.cpp b/src/Simplify_And.cpp index 1b35fda68340..31975da9b9cb 100644 --- a/src/Simplify_And.cpp +++ b/src/Simplify_And.cpp @@ -57,6 +57,7 @@ Expr Simplify::visit(const And *op, ExprInfo *bounds) { rewrite(!x && x, false) || rewrite(y <= x && x < y, false) || rewrite(x != c0 && x == c1, b, c0 != c1) || + rewrite(x == c0 && x == c1, false, c0 != c1) || // Note: In the predicate below, if undefined overflow // occurs, the predicate counts as false. If well-defined // overflow occurs, the condition couldn't possibly diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 7e4b40d2905b..4f8fe8ffcd91 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -388,6 +388,9 @@ Stmt Simplify::visit(const Block *op) { rest.as() ? rest.as() : (block_rest ? block_rest->first.as() : nullptr); Stmt if_rest = block_rest ? block_rest->rest : Stmt(); + const Store *store_first = first.as(); + const Store *store_next = block_rest ? 
block_rest->first.as() : rest.as(); + if (is_no_op(first) && is_no_op(rest)) { return Evaluate::make(0); @@ -411,11 +414,24 @@ Stmt Simplify::visit(const Block *op) { new_block = substitute(let_rest->name, new_var, new_block); return LetStmt::make(var_name, let_first->value, new_block); + } else if (store_first && + store_next && + equal(store_first->index, store_next->index) && + equal(store_first->predicate, store_next->predicate) && + is_pure(store_first->index) && + is_pure(store_first->value) && + is_pure(store_first->predicate)) { + // Second store clobbers first + if (block_rest) { + return Block::make(store_next, block_rest->rest); + } else { + return store_next; + } } else if (if_first && if_next && equal(if_first->condition, if_next->condition) && is_pure(if_first->condition)) { - // Two ifs with matching conditions + // Two ifs with matching conditions. Stmt then_case = mutate(Block::make(if_first->then_case, if_next->then_case)); Stmt else_case; if (if_first->else_case.defined() && if_next->else_case.defined()) { @@ -448,6 +464,25 @@ Stmt Simplify::visit(const Block *op) { result = Block::make(result, if_rest); } return result; + } else if (if_first && + if_next && + is_pure(if_first->condition) && + is_pure(if_next->condition) && + is_const_one(mutate(!(if_first->condition && if_next->condition), nullptr))) { + // Two ifs where the first condition being true implies the + // second is false. The second if can be nested inside the + // else case of the first one, turning a block of if + // statements into an if-else chain. 
+ Stmt then_case = if_first->then_case; + Stmt else_case = if_next; + if (if_first->else_case.defined()) { + else_case = Block::make(if_first->else_case, else_case); + } + Stmt result = mutate(IfThenElse::make(if_first->condition, then_case, else_case)); + if (if_rest.defined()) { + result = Block::make(result, if_rest); + } + return result; } else if (op->first.same_as(first) && op->rest.same_as(rest)) { return op; From 0597188165464943709e14e5f74f0d0949671178 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 20 Dec 2020 10:46:55 -0800 Subject: [PATCH 11/16] Use less stack space. --- src/SplitTuples.cpp | 86 ++++++++++++++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 29 deletions(-) diff --git a/src/SplitTuples.cpp b/src/SplitTuples.cpp index 5608ee5932e9..cc3745b41579 100644 --- a/src/SplitTuples.cpp +++ b/src/SplitTuples.cpp @@ -367,18 +367,25 @@ class SplitScatterGather : public IRMutator { } } - // Just for error messages. The default value should not - // currently be possible to hit. - string producer_name = "(tuple expression not part of a Func definition)"; - public: + string producer_name; int result = 0; - GetScatterGatherSize(const ProducerConsumer *producer) { - if (producer) { - producer_name = producer->name; - } + } get_scatter_gather_size_visitor; + + int get_scatter_gather_size(const IRNode *op) { + if (producer) { + get_scatter_gather_size_visitor.producer_name = producer->name; } - }; + get_scatter_gather_size_visitor.result = 0; + op->accept(&get_scatter_gather_size_visitor); + + // Maybe the user did something like jam a gather op into a + // constraint or tile size, so this is a user_assert. 
+ user_assert(producer || get_scatter_gather_size_visitor.result == 0) + << "scatter/gather expression used outside of a Func definition"; + + return get_scatter_gather_size_visitor.result; + } class ExtractScatterGatherElement : public IRMutator { using IRMutator::visit; @@ -395,34 +402,29 @@ class SplitScatterGather : public IRMutator { public: int idx; - }; + } extractor; Stmt visit(const ProducerConsumer *op) override { ScopedValue old(producer, op->is_producer ? op : producer); return IRMutator::visit(op); } - Stmt visit(const LetStmt *op) override { - GetScatterGatherSize get_scatter_gather_size(producer); - op->value.accept(&get_scatter_gather_size); - if (get_scatter_gather_size.result == 0) { - return IRMutator::visit(op); - } - - // Split this variable into the tuple components - ExtractScatterGatherElement extractor; - + Stmt visit_gather_let_stmt(const LetStmt *op, int size) { + // Split this variable into the gather components vector> lets; vector vars; - for (extractor.idx = 0; extractor.idx < get_scatter_gather_size.result; extractor.idx++) { + for (extractor.idx = 0; extractor.idx < size; extractor.idx++) { string name = unique_name(op->name + "." 
+ std::to_string(extractor.idx)); lets.emplace_back(name, extractor.mutate(op->value)); vars.push_back(Variable::make(op->value.type(), name)); } Stmt body = op->body; - Expr tuple_replacement = Call::make(op->value.type(), Call::scatter_gather, vars, Call::PureIntrinsic); - body = substitute(op->name, tuple_replacement, body); + Expr gather_replacement = Call::make(op->value.type(), + Call::scatter_gather, + vars, + Call::PureIntrinsic); + body = substitute(op->name, gather_replacement, body); body = mutate(body); for (auto it = lets.rbegin(); it != lets.rend(); it++) { @@ -432,17 +434,43 @@ class SplitScatterGather : public IRMutator { return body; } + Stmt visit(const LetStmt *op) override { + vector> lets; + int size = 0; + Stmt body; + do { + body = op->body; + size = get_scatter_gather_size(op->value.get()); + if (size != 0) { + break; + } + lets.emplace_back(op->name, op->value); + op = body.as(); + } while (op); + + if (size) { + internal_assert(op); + body = visit_gather_let_stmt(op, size); + } else { + internal_assert(op == nullptr); + body = mutate(body); + } + + for (auto it = lets.rbegin(); it != lets.rend(); it++) { + body = LetStmt::make(it->first, it->second, body); + } + + return body; + } + Stmt visit(const Provide *op) override { - GetScatterGatherSize get_scatter_gather_size(producer); - op->accept(&get_scatter_gather_size); - int size = get_scatter_gather_size.result; + int size = get_scatter_gather_size(op); if (size == 0) { return IRMutator::visit(op); } - ExtractScatterGatherElement extractor; - // The LHS should contain at least one tuple, or our scatters + // The LHS should contain at least one scatter op, or our scatters // all go to the same place. Is it worth asserting this? It // could be a bug, or it could be some sort of degenerate base case. 
@@ -470,7 +498,7 @@ class SplitScatterGather : public IRMutator { } Stmt s = Block::make(provides); - // We just duplicated all the non-tuple stuff too, + // We just duplicated all the non-scatter/gather stuff too, // so do joint CSE on the exprs Expr bundle = Call::make(Int(32), Call::bundle, exprs, Call::PureIntrinsic); bundle = common_subexpression_elimination(bundle); From 77873fc15400acb903fe8082d931e7559c36d102 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 20 Dec 2020 12:17:57 -0800 Subject: [PATCH 12/16] Appease clang-tidy --- src/IROperator.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/IROperator.h b/src/IROperator.h index 24cf107243fc..a271db904575 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -1459,13 +1459,13 @@ Expr scatter(const std::vector &args); Expr gather(const std::vector &args); template -Expr scatter(Expr e, Args... args) { - return scatter({e, args...}); +Expr scatter(const Expr &e, Args &&... args) { + return scatter({e, std::forward(args)...}); } template -Expr gather(Expr e, Args... args) { - return gather({e, args...}); +Expr gather(const Expr &e, Args &&... 
args) { + return gather({e, std::forward(args)...}); } // @} From 82d64d2b72d6ffb60e55b15c5b949546be68a2e1 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 27 Dec 2020 16:06:04 -0800 Subject: [PATCH 13/16] Somewhat important condition missing from dead store elimination The first store should be to the same buffer as the second --- src/Simplify_Stmts.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 4f8fe8ffcd91..1fe1e418bbb7 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -416,6 +416,7 @@ Stmt Simplify::visit(const Block *op) { return LetStmt::make(var_name, let_first->value, new_block); } else if (store_first && store_next && + store_first->name == store_next->name && equal(store_first->index, store_next->index) && equal(store_first->predicate, store_next->predicate) && is_pure(store_first->index) && From 4694ec0f370d2561c05b7881dc32649393b6427b Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sun, 27 Dec 2020 16:17:49 -0800 Subject: [PATCH 14/16] Another fix for dead store elimination --- src/Simplify_Stmts.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 1fe1e418bbb7..117d3112ec19 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -1,5 +1,6 @@ #include "Simplify_Internal.h" +#include "ExprUsesVar.h" #include "IRMutator.h" #include "Substitute.h" @@ -421,7 +422,10 @@ Stmt Simplify::visit(const Block *op) { equal(store_first->predicate, store_next->predicate) && is_pure(store_first->index) && is_pure(store_first->value) && - is_pure(store_first->predicate)) { + is_pure(store_first->predicate) && + !expr_uses_var(store_next->index, store_next->name) && + !expr_uses_var(store_next->value, store_next->name) && + !expr_uses_var(store_next->predicate, store_next->name)) { // Second store clobbers first if (block_rest) { return Block::make(store_next, block_rest->rest); From 
04af6b9787a6a05eb389d2a8e6cac2a7188991f6 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 29 Dec 2020 20:47:04 -0800 Subject: [PATCH 15/16] The base class visitor does not take the union --- src/Bounds.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 5d5f5a62aa5b..ce42e91f1040 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1426,9 +1426,13 @@ class Bounds : public IRVisitor { internal_assert(!op->args.empty()); op->args[0].accept(this); } else if (op->is_intrinsic(Call::scatter_gather)) { - // A tuple could evaluate to any one of the args. The base - // class visitor is fine as it takes a union. - IRVisitor::visit(op); + // Take the union of the args + Interval result = Interval::nothing(); + for (const Expr &e : op->args) { + e.accept(this); + result.include(interval); + } + interval = result; } else if (op->call_type == Call::Halide) { bounds_of_func(op->name, op->value_index, op->type); } else { From 5b06a146baf69df1cf2c7dc20a3815e130d1fd63 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Wed, 30 Dec 2020 16:49:16 -0800 Subject: [PATCH 16/16] Address review comments --- src/IROperator.h | 8 +++++ test/correctness/multiple_scatter.cpp | 42 +++++++++++++++++++++++---- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/src/IROperator.h b/src/IROperator.h index a271db904575..be39fb321d8c 100644 --- a/src/IROperator.h +++ b/src/IROperator.h @@ -1422,6 +1422,14 @@ f(x, y, scatter(3, 4)) = gather(1, 2); f(x + scatter(1, 0, 0), y + scatter(0, 3, 0), c + scatter(0, 0, 2)) = 2 * gather(x, y, c); \endcode * +* Repeated values in the scatter cause multiple stores to the same +* location. The stores happen in order from left to right, so the +* rightmost value wins. 
The following code is equivalent to f(x) = 5 +* +\code +f(scatter(x, x)) = gather(3, 5); +\endcode +* * Gathers are most useful for algorithms which require in-place * swapping or permutation of multiple elements, or other kinds of * in-place mutations that require loading multiple inputs, doing some diff --git a/test/correctness/multiple_scatter.cpp b/test/correctness/multiple_scatter.cpp index bd1e8a694469..e05ae75905a9 100644 --- a/test/correctness/multiple_scatter.cpp +++ b/test/correctness/multiple_scatter.cpp @@ -112,11 +112,6 @@ int main(int argc, char **argv) { RDom r(0, (sz + 1) / 2, 0, sz / 2); - vector in{rot(r.x, r.y), - rot(sz - 1 - r.y, r.x), - rot(sz - 1 - r.x, sz - 1 - r.y), - rot(r.y, sz - 1 - r.x)}; - vector src_x{r.x, sz - 1 - r.y, sz - 1 - r.x, r.y}; vector src_y{r.y, r.x, sz - 1 - r.y, sz - 1 - r.x}; vector dst_x = src_x, dst_y = src_y; @@ -213,6 +208,43 @@ int main(int argc, char **argv) { } } + { + // A scatter can exist without a gather if you're just broadcasting + Func f; + Var x; + f(x) = 0; + f(scatter(0, 1, 2, 3)) = 5; + + Buffer out = f.realize(5); + for (int i = 0; i < 5; i++) { + int correct = i < 4 ? 5 : 0; + if (out(i) != correct) { + printf("out(%d) = %d instead of %d\n", i, out(i), correct); + return -1; + } + } + } + + { + // A gather can exist without a scatter, but it's sort of + // silly because last element wins. It's not outright + // disallowed because it may be a degenerate case of some + // generic code. + Func f; + Var x; + f(x) = 0; + f(3) = gather(1, 9); + + Buffer out = f.realize(5); + for (int i = 0; i < 5; i++) { + int correct = i == 3 ? 9 : 0; + if (out(i) != correct) { + printf("out(%d) = %d instead of %d\n", i, out(i), correct); + return -1; + } + } + } + printf("Success!\n"); return 0; }