Skip to content

Commit

Permalink
Remove buggy deinterleave misfeature (#5844)
Browse files Browse the repository at this point in the history
  • Loading branch information
abadams committed Mar 24, 2021
1 parent 92dfc82 commit 9a8ddf7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 87 deletions.
57 changes: 4 additions & 53 deletions src/Deinterleave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,14 +633,14 @@ class Interleaver : public IRMutator {

const int64_t *stride_ptr = as_const_int(r0->stride);

// The stride isn't a constant or is <= 0
if (!stride_ptr || *stride_ptr < 1) {
// The stride isn't a constant or is <= 1
if (!stride_ptr || *stride_ptr <= 1) {
return Stmt();
}

const int64_t stride = *stride_ptr;
const int lanes = r0->lanes;
const int64_t expected_stores = stride == 1 ? lanes : stride;
const int64_t expected_stores = stride;

// Collect the rest of the stores.
std::vector<Stmt> stores;
Expand Down Expand Up @@ -690,53 +690,11 @@ class Interleaver : public IRMutator {
if (*offs < min_offset) {
min_offset = *offs;
}

if (stride == 1) {
// Difference between bases is not a multiple of the lanes.
if (*offs % lanes != 0) {
return Stmt();
}

// This case only triggers if we have an immediate load of the correct stride on the RHS.
// TODO: Could we consider mutating the RHS so that we can handle more complex Exprs than just loads?
const Load *load = stores[i].as<Store>()->value.as<Load>();
if (!load) {
return Stmt();
}
// TODO(psuriana): Predicated load is not currently handled.
if (!is_const_one(load->predicate)) {
return Stmt();
}

const Ramp *ramp = load->index.as<Ramp>();
if (!ramp) {
return Stmt();
}

// Load stride or lanes is not equal to the store lanes.
if (!is_const(ramp->stride, lanes) || ramp->lanes != lanes) {
return Stmt();
}

if (i == 0) {
load_name = load->name;
load_image = load->image;
load_param = load->param;
} else {
if (load->name != load_name) {
return Stmt();
}
}
}
}

// Gather the args for interleaving.
for (size_t i = 0; i < stores.size(); ++i) {
int j = offsets[i] - min_offset;
if (stride == 1) {
j /= stores.size();
}

if (j == 0) {
base = stores[i].as<Store>()->index.as<Ramp>()->base;
}
Expand All @@ -751,14 +709,7 @@ class Interleaver : public IRMutator {
return Stmt();
}

if (stride == 1) {
// Convert multiple dense vector stores of strided vector loads
// into one dense vector store of interleaving dense vector loads.
args[j] = Load::make(t, load_name, stores[i].as<Store>()->index,
load_image, load_param, const_true(t.lanes()), ModulusRemainder());
} else {
args[j] = stores[i].as<Store>()->value;
}
args[j] = stores[i].as<Store>()->value;
predicates[j] = stores[i].as<Store>()->predicate;
}

Expand Down
41 changes: 7 additions & 34 deletions test/correctness/interleave.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,43 +353,24 @@ int main(int argc, char **argv) {
}

{
// Test that transposition works when vectorizing either dimension:
// Test transposition
Func square("square");
square(x, y) = cast(UInt(16), 5 * x + y);

Func trans1("trans1");
trans1(x, y) = square(y, x);

Func trans2("trans2");
trans2(x, y) = square(y, x);
Func trans("trans2");
trans(x, y) = square(y, x);

square.compute_root()
.bound(x, 0, 8)
.bound(y, 0, 8);

trans1.compute_root()
.bound(x, 0, 8)
.bound(y, 0, 8)
.vectorize(x)
.unroll(y);

trans2.compute_root()
trans.compute_root()
.bound(x, 0, 8)
.bound(y, 0, 8)
.unroll(x)
.vectorize(y);

trans1.output_buffer()
.dim(0)
.set_min(0)
.set_stride(1)
.set_extent(8)
.dim(1)
.set_min(0)
.set_stride(8)
.set_extent(8);

trans2.output_buffer()
trans.output_buffer()
.dim(0)
.set_min(0)
.set_stride(1)
Expand All @@ -399,28 +380,20 @@ int main(int argc, char **argv) {
.set_stride(8)
.set_extent(8);

Buffer<uint16_t> result6(8, 8);
Buffer<uint16_t> result7(8, 8);
trans1.realize(result6);
trans2.realize(result7);
trans.realize(result7);

for (int x = 0; x < 8; x++) {
for (int y = 0; y < 8; y++) {
int correct = 5 * y + x;
if (result6(x, y) != correct) {
printf("result(%d) = %d instead of %d\n", x, result6(x, y), correct);
return -1;
}

if (result7(x, y) != correct) {
printf("result(%d) = %d instead of %d\n", x, result7(x, y), correct);
return -1;
}
}
}

check_interleave_count(trans1, 1);
check_interleave_count(trans2, 1);
check_interleave_count(trans, 1);
}

printf("Success!\n");
Expand Down

0 comments on commit 9a8ddf7

Please sign in to comment.