Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix offset caching in lowering #38

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ Definition of a scope that is a stage pipeline:
if (it_atomic != block->annotations.end()) {
is_atomic = ((*it_atomic).second).as<IntImmNode>()->value;
}
if (!is_atomic) {
// Todo(ruihang): Temporary hack. Deal with the "sparse" annotation later.
if (!is_atomic && block->annotations.find("sparse") == block->annotations.end()) {
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
GetRef<Block>(block));
}
Expand Down
27 changes: 17 additions & 10 deletions src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,6 @@ class SparseBufferCtx {
matches_.emplace_back(axis->name == sp_iter_var->axis->name);
}
}

// update offset
PrimExpr new_offset = AggregateOffset(offsets_.back(), axis, std::move(coordinate), ana_);
offsets_.emplace_back(std::move(new_offset));
}

/*! \brief get the axis given dimension index of current buffer. */
Expand All @@ -341,7 +337,7 @@ class SparseBufferCtx {
AggregateOffset(add(offsets_[dim], 1), axis, Integer(0), ana_)};
}

private:
public:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why make it public?

String buf_name_;
Array<Axis> axes_;
std::vector<PrimExpr> offsets_;
Expand Down Expand Up @@ -375,7 +371,12 @@ class SparseBufferCtx {
top()->Register(dim, std::move(coordinate), std::move(orig_idx));
}

private:
void AddOffset(int dim, PrimExpr offset) {
ICHECK_EQ(dim + 1, static_cast<int>(top()->offsets_.size()));
top()->offsets_.push_back(offset);
}

public:
std::vector<Scope> stack_;
arith::Analyzer* ana_;

Expand Down Expand Up @@ -421,18 +422,22 @@ class IndexTransformer : public StmtExprMutator {
auto sf_axis = axis.as<SparseFixedAxisNode>();
PrimExpr l, r;
std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim);
offset = lower_bound(sf_axis->indices->data, coordinate, l, r);
offset = lower_bound(sf_axis->indices->data, coordinate, l, r) - l;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

break;
}
case AxisKind::kSparseVariable:
auto sv_axis = axis.as<SparseVariableAxisNode>();
PrimExpr l, r;
std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim);
offset = lower_bound(sv_axis->indices->data, coordinate, l, r);
offset = lower_bound(sv_axis->indices->data, coordinate, l, r) - l;
break;
}
}

// update offset
PrimExpr new_offset = AggregateOffset(sp_buf_ctx_.top()->offsets_.back(), axis,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we expose a function to manipulate the offset?
I suggest not operating on top() directly.

offset, sp_buf_ctx_.ana_);
sp_buf_ctx_.top()->offsets_.push_back(std::move(new_offset));
return offset;
}

Expand Down Expand Up @@ -562,7 +567,8 @@ class IndexTransformer : public StmtExprMutator {
Axis axis = sp_it_var->axis;
auto parent = axis->GetParentAxis();
bool create_new_blk = false;
bool is_fixed_axis = axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed;
bool is_fixed_axis =
axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed;
if (!is_fixed_axis && parent.defined()) {
const AxisNode* parent_node = parent.value().get();
if (in_block.find(parent_node) != in_block.end()) {
Expand All @@ -572,7 +578,8 @@ class IndexTransformer : public StmtExprMutator {
/* parent node is in the previous blocks in the stack, no need to create new block. */
create_new_blk = false;
} else {
CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before " << axis->GetName() << " when defining a sparse block.";
CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before "
<< axis->GetName() << " when defining a sparse block.";
}
}
if (create_new_blk) {
Expand Down
102 changes: 101 additions & 1 deletion tests/python/sparsetir/test_tir_sparse_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tvm.tir as tir
import scipy.sparse as sp
import numpy as np
import pytest
from tvm.script import tir as T


Expand Down Expand Up @@ -367,7 +368,7 @@ def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j:
J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32")
K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32")
K_indices = T.match_buffer(indices_k, [nnz_k], dtype="int32")

for v_vi in T.serial(0, M):
with T.block("square_sum_2"):
vi = T.axis.spatial(M, v_vi)
Expand All @@ -391,6 +392,58 @@ def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j:
B_data[vi] = B_data[vi] + A_data[K_indptr[J_indptr[vi] + vj] + vk]


@T.prim_func
def square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32):
    # Reduce a 3-level sparse buffer A over its two inner axes: B[i] = sum_{j,k} A[i, j, k].
    # Used only for testing `GetIndicesRange()`.
    # Currently it is ensured that `indptr_k0` is the same as `indptr_k1`, and `indices_k0` is the
    # same as `indices_k1`.
    I = T.dense_fixed(M)
    J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32")
    # K0 and K1 share the parent axis J and (by the test's construction) identical
    # indptr/indices contents, but they are distinct axes as far as lowering is concerned.
    K0 = T.sparse_variable(J, (N2, nnz_k), (indptr_k0, indices_k0), "int32")
    K1 = T.sparse_variable(J, (N2, nnz_k), (indptr_k1, indices_k1), "int32")
    A = T.match_sparse_buffer(a, (I, J, K0), "float32")  # A is laid out along K0
    B = T.match_sparse_buffer(b, (I,), "float32")

    # Iterating along K1 while A is laid out along K0 forces the lowering pass to
    # translate K1 coordinates into K0 offsets (exercising the indices-range lookup).
    with T.iter([I, J, K1], "SRR", "square_sum") as [vi, vj, vk]:
        with T.init():
            B[vi] = 0.0
        B[vi] = B[vi] + A[vi, vj, vk]


@T.prim_func
def lowered_square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None:
    # Expected output of LowerSparseTIR on `square_sum_two_K`; compared with
    # `assert_structural_equal`, so every statement here must match the pass output exactly.
    A_data = T.match_buffer(a, [nnz_k], dtype="float32")
    B_data = T.match_buffer(b, [M], dtype="float32")
    J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32")
    J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32")
    K0_indptr = T.match_buffer(indptr_k0, [nnz_j + 1], dtype="int32")
    K0_indices = T.match_buffer(indices_k0, [nnz_k], dtype="int32")
    K1_indptr = T.match_buffer(indptr_k1, [nnz_j + 1], dtype="int32")
    K1_indices = T.match_buffer(indices_k1, [nnz_k], dtype="int32")

    # Outer spatial loop over the dense I axis.
    for v_vi in T.serial(0, M):
        with T.block("square_sum_2"):
            vi = T.axis.spatial(M, v_vi)
            T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
            T.writes([B_data[0 : M]])
            T.block_attr({"sparse":True})
            # Reduction over the variable-length J row of vi (CSR row extent).
            for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
                with T.block("square_sum_1"):
                    vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj)
                    T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
                    T.writes([B_data[0 : M]])
                    T.block_attr({"sparse":True})
                    with T.init():
                        B_data[vi] = T.float32(0)
                    # Innermost reduction iterates along K1's row of (vi, vj).
                    for v_vk in T.serial(0, K1_indptr[J_indptr[v_vi] + v_vj + 1] - K1_indptr[J_indptr[v_vi] + v_vj]):
                        with T.block("square_sum"):
                            vk = T.axis.reduce(K1_indptr[J_indptr[v_vi] + v_vj + 1] - K1_indptr[J_indptr[v_vi] + v_vj], v_vk)
                            T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
                            T.writes([B_data[0 : M]])
                            T.block_attr({"sparse":True})
                            # A is laid out along K0, so the K1 coordinate is converted to a
                            # K0 offset with a binary search restricted to this row's range
                            # [K0_indptr[row], K0_indptr[row + 1]) — the range produced by
                            # GetIndicesRange() that this PR's offset fix exercises.
                            B_data[vi] = B_data[vi] + A_data[T.tvm_lower_bound(K0_indices.data, K1_indices[K1_indptr[J_indptr[vi] + vj] + vk], K0_indptr[J_indptr[vi] + vj], K0_indptr[J_indptr[vi] + vj + 1], dtype="int32")]


def test_csrmm():
mod = tvm.IRModule.from_expr(csrmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
Expand All @@ -414,13 +467,15 @@ def test_csrmm():
tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5)


@pytest.mark.skip(reason="Under implementation")
def test_csrmm_dense_iter():
    # Placeholder: runs the lowering pass on `csrmm_dense_iter` only as a smoke check;
    # the structural comparison is disabled until the lowered form is finalized.
    mod = tvm.IRModule.from_expr(csrmm_dense_iter)
    mod = tvm.tir.transform.LowerSparseTIR()(mod)
    # tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)
    # Todo


@pytest.mark.skip(reason="Under implementation")
def test_segment_reduce():
mod = tvm.IRModule.from_expr(segment_reduce)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
Expand Down Expand Up @@ -557,6 +612,7 @@ def test_csr_element_wise():
tvm.testing.assert_allclose(b_ground_truth.data.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5)


@pytest.mark.skip(reason="Under implementation")
def test_bmm():
mod = tvm.IRModule.from_expr(bmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
Expand Down Expand Up @@ -600,6 +656,49 @@ def test_square_sum():
tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5)


def test_square_sum_two_K():
    # End-to-end check for the two-K sparse reduction: verifies both the lowered
    # structure and the numerical result on a CUDA build.
    mod = tvm.IRModule.from_expr(square_sum_two_K)
    mod = tvm.tir.transform.LowerSparseTIR()(mod)
    tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum_two_K, True)

    sch = tir.Schedule(mod, debug_mask="all")
    i, = sch.get_loops(sch.get_block("square_sum_2"))
    sch.bind(i, "threadIdx.x")  # one thread per row i

    # Build a two-level CSR fixture: rows of J, then rows of K under each J entry.
    # The same (indptr_k, indices_k) arrays are fed in as both K0 and K1, matching
    # the prim_func's assumption that the two K axes are identical.
    density = 0.0125
    M = N1 = N2 = 128
    # Row-of-J is nonzero iff at least one of its N2 K-entries is nonzero, hence
    # the 1 - (1 - density) ** N2 density for the outer matrix.
    A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr")
    indptr_j = A_J.indptr
    indices_j = A_J.indices
    nnz_j = A_J.nnz
    A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr")
    indptr_k = A_K.indptr
    indices_k = A_K.indices
    nnz_k = A_K.nnz
    data = A_K.data

    # Ground truth: sum K within each (i, j) entry, then sum those entries per row i.
    b_ij = np.asarray(A_K.sum(axis=1)).squeeze()
    A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1))
    b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze()
    b = np.zeros((M,)).astype("float32")

    # Specialize the symbolic shape parameters (last five params of the prim_func).
    v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = square_sum_two_K.params[-5:]
    f = tvm.build(sch.mod["main"].specialize({v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="cuda")

    ctx = tvm.device("cuda")
    A_data = tvm.nd.array(data.astype("float32"), device=ctx)
    A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx)
    A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx)
    A_indptr_k0 = tvm.nd.array(indptr_k.astype("int32"), device=ctx)
    A_indices_k0 = tvm.nd.array(indices_k.astype("int32"), device=ctx)
    A_indptr_k1 = tvm.nd.array(indptr_k.astype("int32"), device=ctx)
    A_indices_k1 = tvm.nd.array(indices_k.astype("int32"), device=ctx)
    B_data = tvm.nd.array(b.astype("float32"), device=ctx)
    f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k0, A_indices_k0, A_indptr_k1, A_indices_k1)

    tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
test_csrmm()
test_csrmm_dense_iter()
Expand All @@ -610,3 +709,4 @@ def test_square_sum():
test_csr_element_wise()
test_bmm()
test_square_sum()
test_square_sum_two_K()