diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 97444bdb61e0..5f57ccc941e0 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -168,7 +168,8 @@ Definition of a scope that is a stage pipeline: if (it_atomic != block->annotations.end()) { is_atomic = ((*it_atomic).second).as()->value; } - if (!is_atomic) { + // Todo(ruihang): Temporary hack. Deal with the "sparse" annotation later. + if (!is_atomic && block->annotations.find("sparse") == block->annotations.end()) { throw NotCompactDataFlowError(self->mod, GetRef(scope_root_subtree->stmt), GetRef(block)); } diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 5e28b1974a06..51fd31c241b4 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -322,10 +322,6 @@ class SparseBufferCtx { matches_.emplace_back(axis->name == sp_iter_var->axis->name); } } - - // update offset - PrimExpr new_offset = AggregateOffset(offsets_.back(), axis, std::move(coordinate), ana_); - offsets_.emplace_back(std::move(new_offset)); } /*! \brief get the axis given dimension index of current buffer. */ @@ -341,7 +337,7 @@ class SparseBufferCtx { AggregateOffset(add(offsets_[dim], 1), axis, Integer(0), ana_)}; } - private: + public: String buf_name_; Array axes_; std::vector offsets_; @@ -375,7 +371,12 @@ class SparseBufferCtx { top()->Register(dim, std::move(coordinate), std::move(orig_idx)); } - private: + void AddOffset(int dim, PrimExpr offset) { + ICHECK_EQ(dim + 1, static_cast(top()->offsets_.size())); + top()->offsets_.push_back(offset); + } + + public: std::vector stack_; arith::Analyzer* ana_; @@ -421,18 +422,22 @@ class IndexTransformer : public StmtExprMutator { auto sf_axis = axis.as(); PrimExpr l, r; std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim); - offset = lower_bound(sf_axis->indices->data, coordinate, l, r); + offset = lower_bound(sf_axis->indices->data, coordinate, l, r) - l; break; } case AxisKind::kSparseVariable: auto sv_axis = axis.as(); PrimExpr l, r; std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim); - offset = lower_bound(sv_axis->indices->data, coordinate, l, r); + offset = lower_bound(sv_axis->indices->data, coordinate, l, r) - l; break; } } + // update offset + PrimExpr new_offset = AggregateOffset(sp_buf_ctx_.top()->offsets_.back(), axis, + offset, sp_buf_ctx_.ana_); + sp_buf_ctx_.top()->offsets_.push_back(std::move(new_offset)); return offset; } @@ -562,7 +567,8 @@ class IndexTransformer : public StmtExprMutator { Axis axis = sp_it_var->axis; auto parent = axis->GetParentAxis(); bool create_new_blk = false; - bool is_fixed_axis = axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed; + bool is_fixed_axis = + axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed; if (!is_fixed_axis && parent.defined()) { const AxisNode* parent_node = parent.value().get(); if (in_block.find(parent_node) != in_block.end()) { @@ -572,7 +578,8 @@ class IndexTransformer : public StmtExprMutator { /* parent node is in the previous blocks in the stack, no need to create new block. */ create_new_blk = false; } else { - CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before " << axis->GetName() << " when defining a sparse block."; + CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before " + << axis->GetName() << " when defining a sparse block."; } } if (create_new_blk) { diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py index de2a1c1b49ab..eb538f961351 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower.py +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -19,6 +19,7 @@ import tvm.tir as tir import scipy.sparse as sp import numpy as np +import pytest from tvm.script import tir as T @@ -367,7 +368,7 @@ def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") K_indices = T.match_buffer(indices_k, [nnz_k], dtype="int32") - + for v_vi in T.serial(0, M): with T.block("square_sum_2"): vi = T.axis.spatial(M, v_vi) @@ -391,6 +392,58 @@ def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: B_data[vi] = B_data[vi] + A_data[K_indptr[J_indptr[vi] + vj] + vk] +@T.prim_func +def square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32): + # Used only for testing `GetIndicesRange()`. + # Currently it is ensured that `indptr_k0` is the same as `indptr_k1`, and `indices_k0` is the + # same as `indices_k1`. + I = T.dense_fixed(M) + J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32") + K0 = T.sparse_variable(J, (N2, nnz_k), (indptr_k0, indices_k0), "int32") + K1 = T.sparse_variable(J, (N2, nnz_k), (indptr_k1, indices_k1), "int32") + A = T.match_sparse_buffer(a, (I, J, K0), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + + with T.iter([I, J, K1], "SRR", "square_sum") as [vi, vj, vk]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj, vk] + + +@T.prim_func +def lowered_square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None: + A_data = T.match_buffer(a, [nnz_k], dtype="float32") + B_data = T.match_buffer(b, [M], dtype="float32") + J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32") + J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") + K0_indptr = T.match_buffer(indptr_k0, [nnz_j + 1], dtype="int32") + K0_indices = T.match_buffer(indices_k0, [nnz_k], dtype="int32") + K1_indptr = T.match_buffer(indptr_k1, [nnz_j + 1], dtype="int32") + K1_indices = T.match_buffer(indices_k1, [nnz_k], dtype="int32") + + for v_vi in T.serial(0, M): + with T.block("square_sum_2"): + vi = T.axis.spatial(M, v_vi) + T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]]) + T.writes([B_data[0 : M]]) + T.block_attr({"sparse":True}) + for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]): + with T.block("square_sum_1"): + vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj) + T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]]) + T.writes([B_data[0 : M]]) + T.block_attr({"sparse":True}) + with T.init(): + B_data[vi] = T.float32(0) + for v_vk in T.serial(0, K1_indptr[J_indptr[v_vi] + v_vj + 1] - K1_indptr[J_indptr[v_vi] + v_vj]): + with T.block("square_sum"): + vk = T.axis.reduce(K1_indptr[J_indptr[v_vi] + v_vj + 1] - K1_indptr[J_indptr[v_vi] + v_vj], v_vk) + T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]]) + T.writes([B_data[0 : M]]) + T.block_attr({"sparse":True}) + B_data[vi] = B_data[vi] + A_data[T.tvm_lower_bound(K0_indices.data, K1_indices[K1_indptr[J_indptr[vi] + vj] + vk], K0_indptr[J_indptr[vi] + vj], K0_indptr[J_indptr[vi] + vj + 1], dtype="int32")] + + def test_csrmm(): mod = tvm.IRModule.from_expr(csrmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) @@ -414,6 +467,7 @@ def test_csrmm(): tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) +@pytest.mark.skip(reason="Under implementation") def test_csrmm_dense_iter(): mod = tvm.IRModule.from_expr(csrmm_dense_iter) mod = tvm.tir.transform.LowerSparseTIR()(mod) @@ -421,6 +475,7 @@ def test_csrmm_dense_iter(): # Todo +@pytest.mark.skip(reason="Under implementation") def test_segment_reduce(): mod = tvm.IRModule.from_expr(segment_reduce) mod = tvm.tir.transform.LowerSparseTIR()(mod) @@ -557,6 +612,7 @@ def test_csr_element_wise(): tvm.testing.assert_allclose(b_ground_truth.data.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5) +@pytest.mark.skip(reason="Under implementation") def test_bmm(): mod = tvm.IRModule.from_expr(bmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) @@ -600,6 +656,49 @@ def test_square_sum(): tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5) +def test_square_sum_two_K(): + mod = tvm.IRModule.from_expr(square_sum_two_K) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum_two_K, True) + + sch = tir.Schedule(mod, debug_mask="all") + i, = sch.get_loops(sch.get_block("square_sum_2")) + sch.bind(i, "threadIdx.x") + + density = 0.0125 + M = N1 = N2 = 128 + A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr") + indptr_j = A_J.indptr + indices_j = A_J.indices + nnz_j = A_J.nnz + A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr") + indptr_k = A_K.indptr + indices_k = A_K.indices + nnz_k = A_K.nnz + data = A_K.data + + b_ij = np.asarray(A_K.sum(axis=1)).squeeze() + A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1)) + b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze() + b = np.zeros((M,)).astype("float32") + + v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = square_sum_two_K.params[-5:] + f = tvm.build(sch.mod["main"].specialize({v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="cuda") + + ctx = tvm.device("cuda") + A_data = tvm.nd.array(data.astype("float32"), device=ctx) + A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx) + A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx) + A_indptr_k0 = tvm.nd.array(indptr_k.astype("int32"), device=ctx) + A_indices_k0 = tvm.nd.array(indices_k.astype("int32"), device=ctx) + A_indptr_k1 = tvm.nd.array(indptr_k.astype("int32"), device=ctx) + A_indices_k1 = tvm.nd.array(indices_k.astype("int32"), device=ctx) + B_data = tvm.nd.array(b.astype("float32"), device=ctx) + f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k0, A_indices_k0, A_indptr_k1, A_indices_k1) + + tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5) + + if __name__ == "__main__": test_csrmm() test_csrmm_dense_iter() @@ -610,3 +709,4 @@ def test_square_sum(): test_csr_element_wise() test_bmm() test_square_sum() + test_square_sum_two_K()