From 7540891b4798c59dab9c382920709365ce2e7d40 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 20 Oct 2021 06:51:42 -0700 Subject: [PATCH] Format and Buffer data structure (#1) [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye Fix AxisTree (#3) * fix axis tree * upd [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5) * Add dtype for SparseBuffer * Add name for SparseBuffer. Remove `ndim` * Remove namespace sparse * Add SparseBufferLoad/Store * Add method `ndim()` [SparseTIR] Introduce SpIterVar (#6) * [SparseTIR] Introduce SpIterVar * Add conversion to PrimExpr [BugFix] Fix binary search & SpIterVar (#7) [BugFix] Add field `is_reduction` for SpIterVar (#9) * [BugFix] Add field `is_reduction` for SpIterVar * Formatting [SparseTIR] Index Lowering (#8) * Add StmtFunctor/ExprFunctor for SparseBufferStore/Load * Add basic index lowering * Finish index lowering (maybe) * Address comments * Convert CRLF to LF Frontend update, demo scripts. (#10) * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye * Fix AxisTree (#3) * fix axis tree * upd * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * fix axis tree * upd * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye * Fix AxisTree (#3) * fix axis tree * upd * [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5) * Add dtype for SparseBuffer * Add name for SparseBuffer. Remove `ndim` * Remove namespace sparse * Add SparseBufferLoad/Store * Add method `ndim()` * Format and Buffer data structure (#1) * [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (#2) * add methods for Object * axis constructors * methods for SparseBuffer * put into registry * python interface * [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (#483) (#4) * upd * upd * fix * upd * upd * upd * upd * upd * fix * upd * upd * upd * upd * upd * upd * upd * codegen-rule * upd * upd * test * upd * fix * two arguments Co-authored-by: Zihao Ye * Fix AxisTree (#3) * fix axis tree * upd * [SparseTIR] Add SparseBufferLoad/SparseBufferStore (#5) * Add dtype for SparseBuffer * Add name for SparseBuffer. Remove `ndim` * Remove namespace sparse * Add SparseBufferLoad/Store * Add method `ndim()` * [SparseTIR] Introduce SpIterVar (#6) * [SparseTIR] Introduce SpIterVar * Add conversion to PrimExpr * [BugFix] Fix binary search & SpIterVar (#7) * [BugFix] Add field `is_reduction` for SpIterVar (#9) * [BugFix] Add field `is_reduction` for SpIterVar * Formatting * upd * upd Co-authored-by: Ruihang Lai [SparseTIR] SparseBlock on C++/Python side (#11) * Fix a bug in the last commit * SparseBlock on C++ & Python side [BugFix][SparseTIR] TVMScript Parser for Axis & SpIterVar (#12) * Update `cord` and `pos` * Fix `idtype` * Formatting.. * Bug fix 1 * Move new special stmts * Parser for Axis and SpIterVar * Fix context_maintainer.py [SparseTIR] Enhance SparseBlock to contain enough PrimFunc information (#13) * Enhance SparseBlock to have enough PrimFunc info * Remove `func_sparse_buffer_map_` * Don't print the map uh-huh [SparseTIR] Parser, Printer, Roundtrip (#14) * SparseBlock scope handler (part 1) * SparseBlock scope handler (part 2) * SparseBlock scope handler (part 3) * SparseBlock scope handler (fix 1) * Add SparseBufferLoad/Store on Python side * Parser for SparseBufferLoad/Store * Add SparseBlock to Python __init__ * StmtFunctor for SparseBlock * Ensure at least one dimension for SparseBuffer * Make `axis` field of SpIterVar mandatory * SparseBlock scope handler (fix 2) * Update Axis syntax by removing `name` parameter * Move to intrin.py * Add filed `from_sparse` to DenseFixedAxis * SparseTIR script printer * Roundtrip test * `update_symbol` bug fix * Fix attr visit in SparseBuffer * Define then compare in SparseBlock * Fix printer bug for SparseBuffer * Enable graph match for Axis and SparseBuffer * Complete HashReduce and EqualReduce for AxisTree and SparseBuffer * Fix typo * Rename test * Bug fix 1 * Bug fix 2 * Add more tests Move tests (#15) [SparseTIR] ReprPrinter for Axis and SpIterVar (#16) upd (#17) flatten (#18) ELL and BSR correctness test scripts (#19) [SparseTIR] SparseTIR Lowering (#20) * Fix a previous bug of sparse-fixed SpIterVar creation * Fix a previous bug in `GetDenseValue` * Refactor Collector and IndexTransformer * Construct block and loops * Fix a previous bug which rejects DV iters in collector * Update buffer map * Create root block * Fix bug of sparse-fixed SpIterVar creation * Fix bug on SpIterVar conversion (with refactor) * Fix bug when getting dependent SpIterVars * Fix bug on dependency map and index lowering * Full block read/write region * Test version 1 * Fix bug of loop order * Fix bug of batch-mm iterator ordering * Update PrimFunc args to use symbolic params * Fix bug of test "csr_element_wise" * Fix bug of index accumulation for sparse-fixed axis * Update correctness test * Test structural equality * Refactor and use Array fix nnz cols Add docstring for sparse tir lowering (#21) * add docstring * upd Add more examples part 1 (sddmm) (#22) * upd * upd * upd [SparseTIR][Schedule] SparseBlockRV, GetSparseBlock, SparseReorder (#23) * Test initialization * Fix a stupid bug of ReprPrinter * Add SparseBlockRV * Schedule: GetSparseBlock * Schedule: Reorder [SparseTIR][Schedule] GetSpIters (#24) remove hybrid script for successful compilation Add atomic intrinsic for output nonzero inference. (#25) * upd * upd Add "sparse" block attribute. (#26) Revert "remove hybrid script for successful compilation" This reverts commit eebd7c18d188db5e335945e41cd024d24c8e1c54. [SparseTIR] Hack `IsAffineBinding` check (#27) * [TensorIR][Schedule] Inherit block anotation upon creating new blocks * Fix SDDMM test * Hack IsAffineBinding for sparse blocks Axis Dependency Tree aware code-gen and bmm example (#28) * upd * upd * upd * upd * upd * upd * upd * upd * remove redundancy * fix * upd * upd Re-design Indices lowering (#29) * upd * upd * upd * upd * upd * init * format * fix * revise coding-style * format Complete indices lowering (#30) * upd * upd * upd * done * upd * passed test * upd Add more docstrings and depress warnings for new lowering algorithm. (#31) Refactor derived axis, frontend support of fusion. (#32) * upd * upd * fix Fatal bugfix and change the signature of DenseVariableAxis. (#33) Syntax simplification (#34) Change the order of generated blocks for block isolation. (#35) * upd * upd * upd Syntax of AttachAxis for BMM (#36) * upd * upd * upd [SparseTIR] Add "square sum" lowering test (#37) * Add square sum test * Remove pylint comment [BugFix] Fix offset caching in lowering (#38) * Hack compact dataflow check in a dirty way * Add two-K square sum test * Mark skipped tests * Fix offset saving in lowering Fusion syntax fix + SDDMM example. (#39) Some structure change on update offsets. (#40) [Refactor] SparseTIR Lowering (#41) * Take out methods in Scope * Refactor * Refactor "match" * Tweak scope contents * Refactor ViewIndexInAxis * Refactor Scope * SDDMM tests under implementation * Refactor block stack * Use Map for var_map * Extract NeedCreateNewBlock * Simplify SpIterVarToIterVar via GetIterExtent * Refactor NeedCreateNewBlock * Add docstring * Use "auto" correctly * Minor refactor and use some move Remove redundant analyzers (#42) Support indices lowering for attach and fuse. (#43) * upd * upd * upd Fix irregular BMM example. (#44) * upd * upd * upd * upd RGCN forward and butterfly pattern example. (#45) Fused SDDMM example. (#46) * upd * wip * fix Fix sparse reorder after refactor (#47) [Refactor] Refactor Unittest (#48) * upd * remove redundancy [Unittest] Correctness test for benchmarking scripts (#49) Bugfix and more test for axis fusion, new workload (#50) * upd * upd upd --- include/tvm/tir/builtin.h | 15 + include/tvm/tir/expr.h | 53 ++ include/tvm/tir/expr_functor.h | 4 + include/tvm/tir/op.h | 32 + include/tvm/tir/schedule/schedule.h | 56 ++ include/tvm/tir/schedule/state.h | 7 + include/tvm/tir/sparse.h | 559 +++++++++++++++ include/tvm/tir/stmt.h | 115 ++- include/tvm/tir/stmt_functor.h | 8 + include/tvm/tir/transform.h | 6 + python/tvm/script/context_maintainer.py | 20 +- python/tvm/script/parser.py | 10 + python/tvm/script/tir/intrin.py | 41 +- python/tvm/script/tir/scope_handler.py | 84 +++ python/tvm/script/tir/special_stmt.py | 211 ++++++ python/tvm/tir/__init__.py | 8 +- python/tvm/tir/_ffi_api.py | 1 + python/tvm/tir/expr.py | 22 + python/tvm/tir/op.py | 76 ++ python/tvm/tir/schedule/schedule.py | 92 ++- python/tvm/tir/sparse.py | 305 ++++++++ python/tvm/tir/stmt.py | 89 ++- python/tvm/tir/transform/transform.py | 12 +- src/printer/tvmscript_printer.cc | 218 ++++++ src/target/source/codegen_cuda.cc | 36 + src/target/source/codegen_cuda.h | 2 + .../source/literal/cuda_binary_search.h | 69 ++ src/tir/ir/expr.cc | 30 + src/tir/ir/expr_functor.cc | 14 + src/tir/ir/sparse.cc | 461 ++++++++++++ src/tir/ir/stmt.cc | 131 +++- src/tir/ir/stmt_functor.cc | 43 ++ src/tir/op/builtin.cc | 9 + src/tir/op/op.cc | 21 + src/tir/schedule/analysis.h | 11 + src/tir/schedule/analysis/analysis.cc | 20 + src/tir/schedule/concrete_schedule.cc | 50 ++ src/tir/schedule/concrete_schedule.h | 35 + src/tir/schedule/primitive.h | 14 + .../primitive/sparse_loop_transformation.cc | 174 +++++ src/tir/schedule/schedule.cc | 17 + src/tir/schedule/traced_schedule.cc | 24 + src/tir/schedule/traced_schedule.h | 4 + src/tir/transforms/lower_sparse_tir.cc | 676 ++++++++++++++++++ tests/python/sparsetir/bench_rgcn.py | 175 +++++ tests/python/sparsetir/bench_rgcn_new.py | 210 ++++++ tests/python/sparsetir/lowered_tir.py | 464 ++++++++++++ tests/python/sparsetir/sparse_tir_scripts.py | 361 ++++++++++ tests/python/sparsetir/test_butterfly.py | 38 + .../sparsetir/test_tir_sparse_buffer.py | 28 + .../sparsetir/test_tir_sparse_correctness.py | 403 +++++++++++ .../python/sparsetir/test_tir_sparse_lower.py | 119 +++ .../test_tir_sparse_nnz_inference.py | 70 ++ .../sparsetir/test_tir_sparse_schedule.py | 238 ++++++ .../test_tir_sparse_script_roundtrip.py | 186 +++++ .../sparsetir/test_tir_sparse_tensorize.py | 1 + tests/python/unittest/test_tir_intrin.py | 86 ++- 57 files changed, 6241 insertions(+), 23 deletions(-) create mode 100644 include/tvm/tir/sparse.h create mode 100644 python/tvm/tir/sparse.py create mode 100644 src/target/source/literal/cuda_binary_search.h create mode 100644 src/tir/ir/sparse.cc create mode 100644 src/tir/schedule/primitive/sparse_loop_transformation.cc create mode 100644 src/tir/transforms/lower_sparse_tir.cc create mode 100644 tests/python/sparsetir/bench_rgcn.py create mode 100644 tests/python/sparsetir/bench_rgcn_new.py create mode 100644 tests/python/sparsetir/lowered_tir.py create mode 100644 tests/python/sparsetir/sparse_tir_scripts.py create mode 100644 tests/python/sparsetir/test_butterfly.py create mode 100644 tests/python/sparsetir/test_tir_sparse_buffer.py create mode 100644 tests/python/sparsetir/test_tir_sparse_correctness.py create mode 100644 tests/python/sparsetir/test_tir_sparse_lower.py create mode 100644 tests/python/sparsetir/test_tir_sparse_nnz_inference.py create mode 100644 tests/python/sparsetir/test_tir_sparse_schedule.py create mode 100644 tests/python/sparsetir/test_tir_sparse_script_roundtrip.py create mode 100644 tests/python/sparsetir/test_tir_sparse_tensorize.py diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index d8a5ea67d844..1c92fc2f537d 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -494,6 +494,21 @@ TVM_DLL const Op& tvm_warp_shuffle_up(); TVM_DLL const Op& tvm_warp_shuffle_down(); TVM_DLL const Op& tvm_warp_activemask(); +/*! + * \brief Lower bound function for binary search. + */ +TVM_DLL const Op& tvm_lower_bound(); + +/*! + * \brief Upper bound function for binary search. + */ +TVM_DLL const Op& tvm_upper_bound(); + +/*! + * \brief Atomic add function. + */ +TVM_DLL const Op& tvm_atomic_add(); + /*! * \brief Initialize the global barrier. * Call this at beginning of kernel that need global barrier. diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index f6741112f269..b17db12d714d 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -643,6 +644,58 @@ class BufferLoad : public PrimExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; +/*! + * \brief Load value from the high dimension sparse buffer. + * + * \code + * + * value = buffer[i, j]; + * + * \endcode + * \sa SparseBufferStore + */ +class SparseBufferLoadNode : public PrimExprNode { + public: + /*! \brief The buffer to be loaded. */ + SparseBuffer buffer; + /*! \brief The indices location to be loaded. */ + Array indices; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("dtype", &(this->dtype)); + v->Visit("buffer", &buffer); + v->Visit("indices", &indices); + v->Visit("span", &span); + } + + bool SEqualReduce(const SparseBufferLoadNode* other, SEqualReducer equal) const { + return equal(dtype, other->dtype) && equal(buffer, other->buffer) && + equal(indices, other->indices); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(dtype); + hash_reduce(buffer); + hash_reduce(indices); + } + + static constexpr const char* _type_key = "tir.SparseBufferLoad"; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferLoadNode, PrimExprNode); +}; + +/*! + * \brief Managed reference to SparseBufferLoadNode. + * \sa SparseBufferLoadNode + */ +class SparseBufferLoad : public PrimExpr { + public: + TVM_DLL explicit SparseBufferLoad(SparseBuffer buffer, Array indices, + Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferLoad, PrimExpr, SparseBufferLoadNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferLoadNode); +}; + /*! * \brief Load value from the result produced by the producer. * diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index b5f1d64a00c4..2507e734c7a7 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -119,6 +119,7 @@ class ExprFunctor { return VisitExpr_(static_cast(op), std::forward(args)...); } virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const SparseBufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; @@ -165,6 +166,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode); IR_EXPR_FUNCTOR_DISPATCH(LoadNode); IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode); + IR_EXPR_FUNCTOR_DISPATCH(SparseBufferLoadNode); IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode); IR_EXPR_FUNCTOR_DISPATCH(LetNode); IR_EXPR_FUNCTOR_DISPATCH(CallNode); @@ -217,6 +219,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor { void VisitExpr_(const SizeVarNode* op) override; void VisitExpr_(const LoadNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; + void VisitExpr_(const SparseBufferLoadNode* op) override; void VisitExpr_(const ProducerLoadNode* op) override; void VisitExpr_(const LetNode* op) override; void VisitExpr_(const CallNode* op) override; @@ -264,6 +267,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor { PrimExpr VisitExpr_(const SizeVarNode* op) override; PrimExpr VisitExpr_(const LoadNode* op) override; PrimExpr VisitExpr_(const BufferLoadNode* op) override; + PrimExpr VisitExpr_(const SparseBufferLoadNode* op) override; PrimExpr VisitExpr_(const ProducerLoadNode* op) override; PrimExpr VisitExpr_(const LetNode* op) override; PrimExpr VisitExpr_(const CallNode* op) override; diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 9c3ea135c68d..f9921b87c8c7 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -820,6 +820,38 @@ TVM_DLL PrimExpr round(PrimExpr x, Span span = Span()); */ TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span()); +/*! + * \brief Lower bound function for binary search + * \param arr The buffer variable of the array to be looked up in + * \param val The value to be looked up in the array + * \param l The left boundary of the look-up range (inclusive) + * \param r The right boundary of the look-up range (exclusive) + * \param span The location of this operation in the source + * \return The look-up result + */ +TVM_DLL PrimExpr lower_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r, + Span span = Span()); + +/*! + * \brief Upper bound function for binary search + * \param arr The buffer variable of the array to be looked up in + * \param val The value to be looked up in the array + * \param l The left boundary of the look-up range (inclusive) + * \param r The right boundary of the look-up range (exclusive) + * \param span The location of this operation in the source + * \return The look-up result + */ +TVM_DLL PrimExpr upper_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r, + Span span = Span()); + +/*! + * \brief Perform atomic add on ptr by val, and return the old value. + * \param ptr The address to perform atomic add. + * \param val The value to add. + * \return The old result stored in ptr. + */ +TVM_DLL PrimExpr atomic_add(tir::Var ptr, PrimExpr val, Span span = Span()); + /*! * \brief Calculate trunc(x) * \param x The input expression. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index be06b44820cd..2f84674c4f00 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace tvm { namespace tir { @@ -85,6 +86,27 @@ using ExprRV = PrimExpr; using ExprRVNode = PrimExprNode; +/**************** Random variable: SparseBlockRV ****************/ + +/*! \brief A random variable that evaluates to a TensorIR sparse block */ +class SparseBlockRVNode : public runtime::Object { + public: + void VisitAttrs(tvm::AttrVisitor* v) {} + static constexpr const char* _type_key = "tir.SparseBlockRV"; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockRVNode, runtime::Object); +}; + +/*! + * \brief Managed reference to SparseBlockRVNode + * \sa SparseBlockRVNode + */ +class SparseBlockRV : public runtime::ObjectRef { + public: + /*! \brief Constructor */ + TVM_DLL SparseBlockRV(); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SparseBlockRV, runtime::ObjectRef, SparseBlockRVNode); +}; + /**************** The Schedule class ****************/ class Schedule; @@ -143,6 +165,12 @@ class ScheduleNode : public runtime::Object { * \return The corresponding expr */ virtual PrimExpr Get(const ExprRV& expr_rv) const = 0; + /*! + * \brief Get the sparse block corresponding to the specific random variable + * \param sp_block_rv The random variable to be looked up + * \return SparseBlock The corresponding sparse block + */ + virtual SparseBlock Get(const SparseBlockRV& sp_block_rv) const = 0; /*! * \brief Get the block sref corresponding to the specific BlockRV * \param block_rv The BlockRV to be looked up @@ -188,6 +216,11 @@ class ScheduleNode : public runtime::Object { * \param expr_rv The random variable to be removed */ virtual void RemoveRV(const ExprRV& expr_rv) = 0; + /*! + * \brief Remove an sparse block random variable from the symbol table + * \param sp_block_rv The random variable to be removed + */ + virtual void RemoveRV(const SparseBlockRV& sp_block_rv) = 0; public: /******** Schedule: Sampling ********/ @@ -524,6 +557,29 @@ class ScheduleNode : public runtime::Object { /******** Schedule: Misc ********/ /*! \brief A no-op that marks the start of postprocessing phase of scheduling */ virtual void EnterPostproc() = 0; + /******** Schedule: SparseTIR schedules ********/ + /*! + * \brief Retrieve a sparse block in a specific function with its name + * \param name The name of the sparse block to be retrieved + * \param func_name The name of the function + * \return The sparse block retrieved + * \note Indexing error is raised if 0 or multiple blocks exist with the specific name + */ + virtual SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") = 0; + /*! + * \brief Retrieve the sparse iterators of a given sparse block + * \param block_rv The block to be queried + * \return The sparse iterators of the input sparse block + */ + virtual Array GetSpIters(const SparseBlockRV& block_rv) = 0; + /*! + * \brief Reorder a list of sparse iterators. It requires the new order to not break the iterator + * dependency. + * \param block The block to be transformed + * \param new_order The new order of the sparse iterators, whose length should equal to the number + * of the input block's sparse iterators + */ + virtual void SparseReorder(const SparseBlockRV& block_rv, const Array& new_order) = 0; }; /*! diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 201d78fe631c..976cea89b912 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -162,6 +162,13 @@ class ScheduleStateNode : public Object { * \return A boolean flag indicating if the block has quasi-affine bindings */ bool IsAffineBlockBinding(const StmtSRef& block_sref) const { + // (SparseTIR Hack) Always return true for sparse blocks. + const auto* block = block_sref->StmtAs(); + Optional sparse_attr = block != nullptr ? block->annotations.Get("sparse") : NullOpt; + if (sparse_attr.defined() && sparse_attr.as()->value == 1) { + return true; + } + return GetBlockInfo(block_sref).affine_binding; } /*! diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h new file mode 100644 index 000000000000..5a44dd19a705 --- /dev/null +++ b/include/tvm/tir/sparse.h @@ -0,0 +1,559 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief tvm/tir/sparse.h + * \brief sparse axes and buffers. + */ +#ifndef TVM_TIR_SPARSE_H_ +#define TVM_TIR_SPARSE_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +enum class AxisKind : int { + kDenseFixed = 0, + kDenseVariable = 1, + kSparseFixed = 2, + kSparseVariable = 3 +}; + +class Axis; + +/*! \brief Common interface for both SparseBlockCtx and SparseBufferAccessCtx. */ +class SparseCtx { + public: + virtual Optional GetPrevAxis(Axis axis) const = 0; + virtual PrimExpr GetCoordinate(Axis axis) const = 0; + virtual PrimExpr GetOffset(Axis axis) const = 0; + virtual void SetCoordinate(Axis axis, PrimExpr coordinate) = 0; + virtual void SetOffset(Axis axis, PrimExpr index) = 0; +}; + +/*! + * \brief Base type for axis in sparse formats. + */ +class AxisNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("length", &length); + } + + bool SEqualReduce(const AxisNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(length, other->length); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(length); + } + + /* name of current axis. */ + String name; + /* length of current axis. For sparse axis, length refers to the upperbound of + * the current axis. */ + PrimExpr length; + + String GetName() const { return name; } + PrimExpr GetLength() const { return length; } + DataType GetIndexType() const { return length->dtype; } + virtual Optional GetParentAxis() const = 0; + + virtual AxisKind kind() const = 0; + virtual PrimExpr GetNNZ() const = 0; + + virtual PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const = 0; + virtual PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const = 0; + virtual PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const = 0; + std::tuple GetOffsetExtent(SparseCtx* ctx) const; + + static constexpr const char* _type_key = "tir.sparse.Axis"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_BASE_OBJECT_INFO(AxisNode, Object); +}; + +/*! + * \brief Managed reference to AxisNode. + * \sa AxisNode + */ +class Axis : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(Axis, ObjectRef, AxisNode); +}; + +/*! + * \brief Dense axis whose column indices are consecutive. + */ +class DenseAxisNode : public AxisNode { + public: + static constexpr const char* _type_key = "tir.sparse.DenseAxis"; + TVM_DECLARE_BASE_OBJECT_INFO(DenseAxisNode, AxisNode); +}; + +/*! + * \brief Managed reference to DenseAxisNode. + * \sa DenseAxisNode + */ +class DenseAxis : public Axis { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DenseAxis, Axis, DenseAxisNode); +}; + +/*! + * \brief Sparse axis whose column indices is not consecutive. + */ +class SparseAxisNode : public AxisNode { + public: + static constexpr const char* _type_key = "tir.sparse.SparseAxis"; + TVM_DECLARE_BASE_OBJECT_INFO(SparseAxisNode, AxisNode); +}; + +/*! + * \brief Managed reference to SparseAxisNode. + * \sa SparseAxisNode + */ +class SparseAxis : public Axis { + public: + TVM_DEFINE_OBJECT_REF_METHODS(SparseAxis, Axis, SparseAxisNode); +}; + +/*! + * \brief Dense axis with fixed length per row. + */ +class DenseFixedAxisNode : public DenseAxisNode { + public: + AxisKind kind() const final { return AxisKind::kDenseFixed; } + + PrimExpr GetNNZ() const final { return length; } + + Optional GetParentAxis() const final { return NullOpt; } + + PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const; + + PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const; + + PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const; + + static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis"; + TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode); +}; + +/*! + * \brief Managed reference to DenseFixedAxisNode. + * \sa DenseFixedAxisNode + */ +class DenseFixedAxis : public DenseAxis { + public: + TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length); + + TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode); +}; + +/*! \brief Derivation axis, constructed by T.dense(axis). */ +class DenseFromSparseAxisNode : public DenseFixedAxisNode { + public: + void VisitAttrs(AttrVisitor* v) { + DenseFixedAxisNode::VisitAttrs(v); + v->Visit("base", &base); + } + + bool SEqualReduce(const DenseFromSparseAxisNode* other, SEqualReducer equal) const { + return DenseFixedAxisNode::SEqualReduce(other, equal) && equal(base, other->base); + } + + void SHashReduce(SHashReducer hash_reduce) const { + DenseFixedAxisNode::SHashReduce(hash_reduce); + hash_reduce(base); + } + + /* The based sparse axis. */ + SparseAxis base; + + static constexpr const char* _type_key = "tir.sparse.DenseFromSparseAxis"; + TVM_DECLARE_FINAL_OBJECT_INFO(DenseFromSparseAxisNode, DenseFixedAxisNode); +}; + +/*! + * \brief Managed reference of DenseFromSparseAxisNode. + * \sa DenseFromSparseAxisNode + */ +class DenseFromSparseAxis : public DenseFixedAxis { + public: + /* DenseFromSparseAxis could be constructed by specifying the based sparse axis. */ + TVM_DLL explicit DenseFromSparseAxis(SparseAxis base); + + TVM_DEFINE_OBJECT_REF_METHODS(DenseFromSparseAxis, DenseFixedAxis, DenseFromSparseAxisNode); +}; + +class FusedAxis; + +/*! \brief Derivation axis, constructed by T.fuse(axis1, axis2, ...) */ +class FusedAxisNode : public DenseFixedAxisNode { + public: + /* The group of axes to be fused. */ + Array group; + /* The index of current FusedAxis in the group. */ + int index; + + void VisitAttrs(AttrVisitor* v) { + DenseFixedAxisNode::VisitAttrs(v); + v->Visit("group", &group); + v->Visit("index", &index); + } + + bool SEqualReduce(const FusedAxisNode* other, SEqualReducer equal) const { + return DenseFixedAxisNode::SEqualReduce(other, equal) && equal(group, other->group) && + equal(index, other->index); + } + + void SHashReduce(SHashReducer hash_reduce) const { + DenseFixedAxisNode::SHashReduce(hash_reduce); + hash_reduce(group); + hash_reduce(index); + } + + static constexpr const char* _type_key = "tir.sparse.FusedAxis"; + TVM_DECLARE_FINAL_OBJECT_INFO(FusedAxisNode, DenseFixedAxisNode); +}; + +/*! + * \brief Managed reference to FusedAxisNode. + * \sa FusedAxisNode + */ +class FusedAxis : public DenseFixedAxis { + public: + /* Fused axis could be constructed by specifying a group of based axes and an index */ + TVM_DLL explicit FusedAxis(Array group, int index); + + TVM_DEFINE_OBJECT_REF_METHODS(FusedAxis, DenseFixedAxis, FusedAxisNode); +}; + +/*! + * \brief Dense axis with variable length, such as ragged tensor. + */ +class DenseVariableAxisNode : public DenseAxisNode { + public: + Buffer indptr; + PrimExpr nnz_; + Axis parent_; + + void VisitAttrs(AttrVisitor* v) { + DenseAxisNode::VisitAttrs(v); + v->Visit("indptr", &indptr); + } + + bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const { + return DenseAxisNode::SEqualReduce(other, equal) && equal(indptr, other->indptr); + } + + void SHashReduce(SHashReducer hash_reduce) const { + DenseAxisNode::SHashReduce(hash_reduce); + hash_reduce(indptr); + } + + AxisKind kind() const final { return AxisKind::kDenseVariable; } + + PrimExpr GetNNZ() const final { return nnz_; } + + Optional GetParentAxis() const final { return parent_; } + + PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const; + + PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const; + + PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const; + + static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis"; + TVM_DECLARE_BASE_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode); +}; + +/*! + * \brief Managed reference to DenseVariableAxisNode. + * \sa DenseVariableAxisNode + */ +class DenseVariableAxis : public DenseAxis { + public: + TVM_DLL explicit DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz, + Buffer indptr); + + TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode); +}; + +/*! + * \brief Dense variable axis attached to another dense variable axis. + */ +class AttachedAxisNode : public DenseVariableAxisNode { + public: + /* The original axis before attaching. */ + DenseVariableAxis orig_; + + PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const; + + static constexpr const char* _type_key = "tir.sparse.AttachedAxis"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttachedAxisNode, DenseVariableAxisNode); +}; + +/*! + * \brief Managed reference to AttachedAxisNode. + * \sa AttachedAxisNode + */ +class AttachedAxis : public DenseVariableAxis { + public: + TVM_DLL explicit AttachedAxis(String name, Axis parent, DenseVariableAxis orig, PrimExpr nnz, + Buffer indptr); + TVM_DEFINE_OBJECT_REF_METHODS(AttachedAxis, DenseVariableAxis, AttachedAxisNode); +}; + +/*! + * \brief Sparse axis with fixed number of non-zero columns per row. + */ +class SparseFixedAxisNode : public SparseAxisNode { + public: + Buffer indices; + Axis parent_; + /* fixed number of non-zero columns of current sparse axis. */ + PrimExpr nnz_cols; + + void VisitAttrs(AttrVisitor* v) { + SparseAxisNode::VisitAttrs(v); + v->Visit("indptr", &indices); + v->Visit("nnz_cols", &nnz_cols); + } + + bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const { + return SparseAxisNode::SEqualReduce(other, equal) && equal(indices, other->indices) && + equal(nnz_cols, other->nnz_cols); + } + + void SHashReduce(SHashReducer hash_reduce) const { + SparseAxisNode::SHashReduce(hash_reduce); + hash_reduce(indices); + hash_reduce(nnz_cols); + } + + PrimExpr GetNNZ() const { return indices->shape[0]; } + + AxisKind kind() const final { return AxisKind::kSparseFixed; } + + Optional GetParentAxis() const final { return parent_; } + + PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const; + + PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const; + + PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const; + + static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis"; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode); +}; + +/*! + * \brief Managed reference to SparseFixedAxisNode. + * \sa SparseFixedAxisNode + */ +class SparseFixedAxis : public SparseAxis { + public: + TVM_DLL explicit SparseFixedAxis(String name, Axis parent, PrimExpr length, Buffer indices, + PrimExpr nnz_cols); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode); +}; + +/*! + * \brief Sparse axis with variable number of non-zero columns per row. + */ +class SparseVariableAxisNode : public SparseAxisNode { + public: + Buffer indptr; + Buffer indices; + Axis parent_; + + void VisitAttrs(AttrVisitor* v) { + SparseAxisNode::VisitAttrs(v); + v->Visit("indptr", &indptr); + v->Visit("indices", &indices); + } + + bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const { + return SparseAxisNode::SEqualReduce(other, equal) && equal(indptr, other->indptr) && + equal(indices, other->indices); + } + + void SHashReduce(SHashReducer hash_reduce) const { + SparseAxisNode::SHashReduce(hash_reduce); + hash_reduce(indptr); + hash_reduce(indices); + } + + PrimExpr GetNNZ() const { return indices->shape[0]; } + + AxisKind kind() const final { return AxisKind::kSparseVariable; } + + Optional GetParentAxis() const final { return parent_; } + + PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const; + + PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const; + + PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const; + + static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis"; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode); +}; + +/*! + * \brief Managed reference to SparseVariableAxisNode. + * \sa SparseVariableAxisNode + */ +class SparseVariableAxis : public SparseAxis { + public: + TVM_DLL explicit SparseVariableAxis(String name, Axis parent, PrimExpr length, Buffer indptr, + Buffer indices); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode); +}; + +/*! + * \brief Class of sparse buffer. + */ +class SparseBufferNode : public Object { + public: + /* Axes */ + Array axes; + /* Buffer corresponding to flattened value */ + Buffer data; + /* Buffer Name */ + String name; + + inline int ndim() const { return static_cast(axes.size()); } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("axes", &axes); + v->Visit("data", &data); + v->Visit("name", &name); + } + + bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const { + return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(axes); + hash_reduce(data); + hash_reduce(name); + } + + static constexpr const char* _type_key = "tir.sparse.SparseBuffer"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object); +}; + +/*! + * \brief Managed reference to SparseBufferNode. + * \sa SparseBufferNode + */ +class SparseBuffer : public ObjectRef { + public: + TVM_DLL explicit SparseBuffer(Array axes, Buffer data, String name); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode); +}; + +// overload printing of for type. +TVM_DLL std::ostream& operator<<(std::ostream& os, AxisKind kind); + +/*! + * \brief Iterator variables in SparseTIR + */ +class SpIterVarNode : public Object { + public: + Var var; + PrimExpr max_extent; + bool is_reduction; + Axis axis; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("max_extent", &max_extent); + v->Visit("axis", &axis); + v->Visit("is_reduction", &is_reduction); + } + + bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const { + return equal(var, other->var) && equal(max_extent, other->max_extent) && + equal(axis, other->axis) && equal(is_reduction, other->is_reduction); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(var); + hash_reduce(max_extent); + hash_reduce(axis); + hash_reduce(is_reduction); + } + + static constexpr const char* _type_key = "tir.sparse.SpIterVar"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SpIterVarNode, Object); +}; + +class SpIterVar : public ObjectRef { + public: + TVM_DLL explicit SpIterVar(Var var, PrimExpr max_extent, bool is_reduction, Axis axis); + + /*! + * \return the corresponding var in the IterVar. + */ + inline operator PrimExpr() const; + + TVM_DEFINE_OBJECT_REF_METHODS(SpIterVar, ObjectRef, SpIterVarNode); +}; + +// inline implementations +inline SpIterVar::operator PrimExpr() const { return (*this)->var; } + +// inline implementations +inline const char* SpIterKind2String(AxisKind t) { + switch (t) { + case AxisKind::kDenseFixed: + return "dense_fixed"; + case AxisKind::kDenseVariable: + return "dense_variable"; + case AxisKind::kSparseFixed: + return "sparse_fixed"; + case AxisKind::kSparseVariable: + return "sparse_variable"; + } + LOG(FATAL) << "Unknown AxisKind" << t; + throw; +} + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_SPARSE_H_ diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 972f78171569..b5ff3686c15e 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -327,6 +327,60 @@ class BufferStore : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); }; +/*! + * \brief Store value to the high dimension sparse buffer. + * + * \code + * + * buffer[i, j] = value; + * + * \endcode + * \sa SparseBufferStore + */ +class SparseBufferStoreNode : public StmtNode { + public: + /*! \brief The sparse buffer to be accessed. */ + SparseBuffer buffer; + /*! \brief The value to be stored. */ + PrimExpr value; + /*! \brief The indices location to be stored. */ + Array indices; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer", &buffer); + v->Visit("value", &value); + v->Visit("indices", &indices); + v->Visit("span", &span); + } + + bool SEqualReduce(const SparseBufferStoreNode* other, SEqualReducer equal) const { + return equal(buffer, other->buffer) && equal(value, other->value) && + equal(indices, other->indices); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(buffer); + hash_reduce(value); + hash_reduce(indices); + } + + static constexpr const char* _type_key = "tir.SparseBufferStore"; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferStoreNode, StmtNode); +}; + +/*! + * \brief Managed reference to SparseBufferStoreNode. + * \sa SparseBufferStoreNode + */ +class SparseBufferStore : public Stmt { + public: + TVM_DLL explicit SparseBufferStore(SparseBuffer buffer, PrimExpr value, Array indices, + Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseBufferStore, Stmt, SparseBufferStoreNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBufferStoreNode); +}; + /*! * \brief Annotate the region where the buffer need to * be read and write in the body. @@ -1316,7 +1370,66 @@ class BlockRealize : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode); }; -/*! \brief namespace of possible attributes in AttrStmt.attr_key */ +/*! + * \brief Sparse Block node. + */ +class SparseBlockNode : public StmtNode { + public: + /*! \brief The sparse iteration variables of the block. */ + Array sp_iter_vars; + /*! \brief The sparse data structures */ + Array sp_structs; + /*! \brief The mapping from sparse data structures to the PrimFunc parameters */ + Map> sp_struct_param_map; + /*! \brief The name of the block */ + String name; + /*! \brief The body of the block */ + Stmt body; + /*! \brief The init statement of the block */ + Optional init; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("sp_iter_vars", &sp_iter_vars); + v->Visit("sp_structs", &sp_structs); + v->Visit("sp_struct_param_map", &sp_struct_param_map); + v->Visit("name", &name); + v->Visit("body", &body); + v->Visit("init", &init); + } + + bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const { + return equal(sp_iter_vars, other->sp_iter_vars) && equal(name, other->name) && + equal(body, other->body) && equal(init, other->init) && + equal(sp_structs, other->sp_structs); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(sp_iter_vars); + hash_reduce(name); + hash_reduce(body); + hash_reduce(init); + hash_reduce(sp_structs); + } + + static constexpr const char* _type_key = "tir.SparseBlock"; + TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode); +}; + +/*! + * \brief Managed reference to SparseBufferNode + * \sa SparseBufferNode + */ +class SparseBlock : public Stmt { + public: + TVM_DLL explicit SparseBlock(Array sp_iter_vars, Array sp_structs, + Array> sp_struct_params, String name, Stmt body, + Optional init = NullOpt, Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode); +}; + +/*! \brief namespace of possible attribute sin AttrStmt.attr_key */ namespace attr { // The above attr does not pass to ir stage. /*! \brief Mark launching extent of thread, used by device API. */ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 16da91c2a2a3..ba99b801f39b 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -90,6 +90,7 @@ class StmtFunctor { virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const SparseBufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -99,6 +100,7 @@ class StmtFunctor { virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const SparseBlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); @@ -123,9 +125,11 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode); + IR_STMT_FUNCTOR_DISPATCH(SparseBufferStoreNode); IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode); IR_STMT_FUNCTOR_DISPATCH(BlockNode); IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode); + IR_STMT_FUNCTOR_DISPATCH(SparseBlockNode); return vtable; } }; @@ -160,6 +164,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const AllocateConstNode* op) override; void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; + void VisitStmt_(const SparseBufferStoreNode* op) override; void VisitStmt_(const BufferRealizeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const ProducerStoreNode* op) override; @@ -169,6 +174,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const BlockNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; + void VisitStmt_(const SparseBlockNode* op) override; }; /*! @@ -261,6 +267,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const AllocateConstNode* op) override; Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; + Stmt VisitStmt_(const SparseBufferStoreNode* op) override; Stmt VisitStmt_(const BufferRealizeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; Stmt VisitStmt_(const ProducerStoreNode* op) override; @@ -270,6 +277,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const EvaluateNode* op) override; Stmt VisitStmt_(const BlockNode* op) override; Stmt VisitStmt_(const BlockRealizeNode* op) override; + Stmt VisitStmt_(const SparseBlockNode* op) override; /*! * \brief Alternative advance method for SeqStmtNode. * diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 4330c4f7c64a..39fa7a82e662 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -617,6 +617,12 @@ TVM_DLL Pass ExtractPrimFuncConstants(); */ TVM_DLL Pass RenormalizeSplitPattern(); +/*! + * \brief Lower SparseTIR to TIR. + * \return The pass. + */ +TVM_DLL Pass LowerSparseTIR(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 149e17bcc701..1d9c7c03e13c 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -26,6 +26,7 @@ from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion from tvm.runtime import Object from tvm.tir.expr import IterVar +from tvm.tir.sparse import Axis, SparseBuffer from .tir.node import BufferSlice @@ -119,7 +120,7 @@ class ContextMaintainer: """List[BlockInfo]: The block info for the current block scope""" loop_stack: Dict[Var, Range] = {} """Dict[Var, Range]: The dict from loop var to its domain outside the block""" - symbols: List[Dict[str, Union[Var, Buffer]]] = [] + symbols: List[Dict[str, Union[Var, Buffer, SparseBuffer, Axis]]] = [] """List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope""" # function context @@ -132,6 +133,12 @@ class ContextMaintainer: func_var_env_dict: Mapping[Var, str] = {} """Mapping[Var, str]: The map from var to env thread""" + # sparse block context + sp_struct: List[Object] = [] + """List[Object]: The sparse data structures""" + sp_struct_params: List[List[Var]] = [] + """List[List[Var]]: The function parameters that corresponding to each sparse data structures""" + # parser and analyzer analyzer: tvm.arith.Analyzer = tvm.arith.Analyzer() """tvm.arith.Analyzer: The analyzer for simplifying""" @@ -153,6 +160,9 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No self.func_buffer_map = {} self.func_dict_attr = {} self.func_var_env_dict = {} + # sparse block context + self.sp_struct = [] + self.sp_struct_params = [] # parser and analyzer self._report_error = _report_error self.analyzer = tvm.arith.Analyzer() @@ -208,9 +218,11 @@ def exit_block_scope(self): # Pop block_info self.block_info_stack.pop() - def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node): + def update_symbol( + self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node + ): """Append a symbol into current scope""" - if isinstance(symbol, Buffer): + if isinstance(symbol, (Buffer, SparseBuffer, Axis)): if name in self.symbols[0]: self.report_error("Duplicate Buffer name: " + symbol.name, node.span) self.symbols[0][name] = symbol @@ -225,7 +237,7 @@ def remove_symbol(self, name: str): return raise RuntimeError("Internal error of tvm script parser: no symbol named " + name) - def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]: + def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var, SparseBuffer, Axis]]: """Look up symbol by name""" for symbols in reversed(self.symbols): if name in symbols: diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 587fbe44a174..ddbcd4f97e81 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -617,6 +617,14 @@ def transform_SubscriptAssign(self, node): indexes, span=tvm_span_from_synr(node.span), ) + elif isinstance(symbol, tvm.tir.sparse.SparseBuffer): + # SparseBufferStore + return tvm.tir.SparseBufferStore( + symbol, + tvm.runtime.convert(rhs, span=rhs_span), + indexes, + span=tvm_span_from_synr(node.span), + ) else: if symbol.dtype == "handle" and len(indexes) != 1: self.report_error( @@ -966,6 +974,8 @@ def transform_Subscript(self, node): return BufferSlice( symbol, indexes, self.report_error, span=tvm_span_from_synr(node.span) ) + elif isinstance(symbol, tvm.tir.sparse.SparseBuffer): + return tvm.tir.SparseBufferLoad(symbol, indexes, span=tvm_span_from_synr(node.span)) elif isinstance(symbol, tvm.container.Array): if len(indexes) > 1: self.report_error( diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index d31e93c72b15..119fc77db7be 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -17,9 +17,20 @@ """TVM Script Parser Intrinsic Classes""" # pylint: disable=redefined-builtin, relative-beyond-top-level import builtins -from typing import List, Any +from typing import List, Optional, Any import tvm.tir +from tvm.ir import Span +from tvm.tir.sparse import ( + Axis, + DenseFixedAxis, + DenseVariableAxis, + SpIterVar, + SparseFixedAxis, + SparseVariableAxis, + DenseFromSparseAxis, + FusedAxis +) from ..registry import register from ..utils import get_param_list, tvm_span_from_synr @@ -111,6 +122,21 @@ def max_value(dtype, span): return tvm.tir.max_value(dtype, span) +@register +def lower_bound(arr, val, l, r, span): + return tvm.tir.lower_bound(arr, val, l, r, span) + + +@register +def upper_bound(arr, val, l, r, span): + return tvm.tir.upper_bound(arr, val, l, r, span) + + +@register +def atomic_add(ptr, val, span): + return tvm.tir.atomic_add(ptr, val, span) + + @register def floordiv(x, y, span): return tvm.tir.floordiv(x, y, span) @@ -234,3 +260,16 @@ def comm_reducer(lambda_io, identities, span): lambda_output = (lambda_output,) return tvm.tir.CommReducer(x, y, lambda_output, identities, span) + + +@register +def dense(axis: Axis, span: Optional[Span] = None): + if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)): + return DenseFromSparseAxis(axis) + else: + return axis + + +@register +def fuse(*group: List[Axis], span: Optional[Span] = None): + return [FusedAxis(group, i) for i, _ in enumerate(group)] diff --git a/python/tvm/script/tir/scope_handler.py b/python/tvm/script/tir/scope_handler.py index 07ba20423161..6582ee22f72b 100644 --- a/python/tvm/script/tir/scope_handler.py +++ b/python/tvm/script/tir/scope_handler.py @@ -20,10 +20,12 @@ import synr import numpy as np +from synr.ast import With import tvm.tir from tvm.runtime import Object, String, convert from tvm.ir import Span, Range from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind +from tvm.tir.sparse import SpIterVar, Axis from .node import BufferSlice from .utils import buffer_slice_to_region @@ -373,6 +375,88 @@ def enter_scope( ) +@register +class SparseBlock(WithScopeHandler): + """With scope handler of SparseBlock""" + + def __init__(self): + def iter(axes: List, iter_types: str, name: str = "", span: Optional[Span] = None): + + # flatten nested axes to axes, to address the special case of fusion. + def flatten_axes(axes: List[Union[Axis, List[Axis]]]) -> List[Axis]: + ret = [] + for axis_group in axes: + if isinstance(axis_group, List): + ret += axis_group + else: + ret.append(axis_group) + return ret + + assert ( + self.node and self.context and self.body + ), "call 'exit_scope' before 'enter_scope'" + block_info = self.context.block_info_stack[-1] + axes = flatten_axes(axes) + + if len(axes) != len(self.sp_iters): + self.context.report_error( + "Inconsistent number of sparse iteration variable names, " + + f"there are {len(axes)} iterators but {len(self.sp_iters)} names. " + + "The number of sparse iteration variable names should match the number of iterators.", + self.node.span, + ) + if len(axes) != len(iter_types): + self.context.report_error( + "Inconsistent number of sparse iteration variable types, " + + f"there are {len(axes)} iterators but {len(iter_types)} types. " + + "The number of sparse iteration variable types should match the number of iterators.", + self.node.span, + ) + + sp_iters: List[SpIterVar] = [] + for i, axis in enumerate(axes): + is_reduction = True if iter_types[i] == "R" else False + sp_iters.append( + SpIterVar( + self.sp_iters[i], + axis.length, + is_reduction, + axis, + ) + ) + + block = tvm.tir.SparseBlock( + sp_iters, + self.context.sp_struct, + self.context.sp_struct_params, + name, + self.body, + block_info.init, + span, + ) + return block + + super().__init__(func=iter, concise_scope=False, def_symbol=True) + self.sp_iters = None + + def enter_scope( + self, + node: synr.ast.Node, + context: ContextMaintainer, + arg_list: List[Any], + span: synr.ast.Span, + ): + # define sparse iteration variables + assert isinstance( + node, synr.ast.With + ), f"SparseBlockScopeHandler expected to work on synr.ast.With but got {type(node)}" + + vars = WithScopeHandler.get_optional_vars(node, context) + self.sp_iters = [tvm.te.var(var.id.name, "int32") for var in vars] + for sp_iter in self.sp_iters: + context.update_symbol(sp_iter.name, sp_iter, node) + + @register class InitBlock(WithScopeHandler): """With scope handler T.init()""" diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 20161ad106c1..e0caa2941d90 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -17,6 +17,7 @@ """TVM Script Parser Special Stmt Classes""" # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements # pylint: disable=relative-beyond-top-level +from os import name from typing import Callable, List, Optional, Tuple, Any, Mapping, Union import synr @@ -29,6 +30,14 @@ from tvm.target import Target from tvm.ir import Span from tvm.tir import IntImm, IterVar +from tvm.tir.sparse import ( + Axis, + DenseFixedAxis, + DenseVariableAxis, + SparseFixedAxis, + SparseVariableAxis, + AttachedAxis, +) from .node import BufferSlice from .utils import buffer_slice_to_region @@ -885,3 +894,205 @@ def __call__(self, target_config): f"T.target expected a config dict or string, but got {type(target_config)}" ) return Target(target_config) + + +@register +class DenseFixed(SpecialStmt): + """Special Stmt for creating dense fixed axis.""" + + def __init__(self): + def dense_fixed(length: PrimExpr, span: Optional[Span] = None): + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"`dense_fixed` expected assign to only one var, but got {names}", span + ) + + axis = DenseFixedAxis(names[0], length) + self.context.sp_struct.append(axis) + self.context.sp_struct_params.append([]) + self.context.update_symbol(names[0], axis, self.node) + + super().__init__(dense_fixed, def_symbol=True) + + +@register +class DenseVariable(SpecialStmt): + """Special Stmt for creating dense variable axis.""" + + def __init__(self): + def dense_variable( + parent_axis: Axis, + shape: Tuple[PrimExpr, PrimExpr], + indptr_var: tvm.tir.Var, + idtype: str = "int32", + span: Optional[Span] = None, + ): + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"`dense_variable` expected assign to only one var, but got {names}", span + ) + + length, nnz = shape + indptr_len = parent_axis.nnz + 1 + indptr_buf = tvm.tir.decl_buffer( + (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span + ) + axis = DenseVariableAxis(names[0], parent_axis, length, nnz, indptr_buf) + self.context.sp_struct.append(axis) + self.context.sp_struct_params.append([indptr_var]) + self.context.update_symbol(names[0], axis, self.node) + self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node) + + super().__init__(dense_variable, def_symbol=True) + + +@register +class Attach(SpecialStmt): + """Special Stmt for attaching axis.""" + + def __init__(self): + def attach_axis( + parent: Axis, + orig: DenseVariableAxis, + nnz: PrimExpr, + indptr_var: tvm.tir.Var, + idtype: str = "int32", + span: Optional[Span] = None, + ): + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"`attach_axis` expected assign to only one var, but got {names}", span + ) + + indptr_len = orig.parent.length + 1 + indptr_buf = tvm.tir.decl_buffer( + (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span + ) + axis = AttachedAxis(names[0], parent, orig, nnz, indptr_buf) + self.context.sp_struct.append(axis) + self.context.sp_struct_params.append([indptr_var]) + self.context.update_symbol(names[0], axis, self.node) + self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node) + + super().__init__(attach_axis, def_symbol=True) + + +@register +class SparseFixed(SpecialStmt): + """Special Stmt for creating sparse fixed axis.""" + + def __init__(self): + def sparse_fixed( + parent_axis: Axis, + shape: Tuple[PrimExpr, PrimExpr], + indices_var: tvm.tir.Var, + idtype: str = "int32", + span: Optional[Span] = None, + ): + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"`sparse_fixed` expected assign to only one var, but got {names}", span + ) + + length, nnz_cols = shape + nnz = parent_axis.nnz * nnz_cols + indices_buf = tvm.tir.decl_buffer( + (nnz,), dtype=idtype, name=names[0] + "_indices", span=span + ) + axis = SparseFixedAxis(names[0], parent_axis, length, indices_buf, nnz_cols) + self.context.sp_struct.append(axis) + self.context.sp_struct_params.append([indices_var]) + self.context.update_symbol(names[0], axis, self.node) + self.context.update_symbol(names[0] + "_indices", indices_buf, self.node) + + super().__init__(sparse_fixed, def_symbol=True) + + +@register +class SparseVariable(SpecialStmt): + """Special Stmt for creating sparse variable axis:""" + + def __init__(self): + def sparse_variable( + parent_axis: Axis, + shape: Tuple[PrimExpr, PrimExpr], + data: Tuple[tvm.tir.Var, tvm.tir.Var], + idtype: str = "int32", + span: Optional[Span] = None, + ): + names = [x.id.name for x in self.node.lhs] + if len(names) != 1: + self.context.report_error( + f"`sparse_variable` expected assign to only one var, but got {names}", span + ) + + length, nnz = shape + indptr_len = parent_axis.nnz + 1 + indptr_var, indices_var = data + indptr_buf = tvm.tir.decl_buffer( + (indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span + ) + indices_buf = tvm.tir.decl_buffer( + (nnz,), dtype=idtype, name=names[0] + "_indices", span=span + ) + axis = SparseVariableAxis(names[0], parent_axis, length, indptr_buf, indices_buf) + self.context.sp_struct.append(axis) + self.context.sp_struct_params.append([indptr_var, indices_var]) + self.context.update_symbol(names[0], axis, self.node) + self.context.update_symbol(names[0] + "_indptr", indptr_buf, self.node) + self.context.update_symbol(names[0] + "_indices", indices_buf, self.node) + + super().__init__(sparse_variable, def_symbol=True) + + +@register +class MatchSparseBuffer(SpecialStmt): + """Special Stmt match_sparse_buffer()""" + + def __init__(self): + def match_sparse_buffer( + param: tvm.tir.Var, + axes: List[Axis], + dtype: str = "float32", + span: Optional[Span] = None, + ): + def infer_nnz(axes: List[Axis]) -> PrimExpr: + """Inference the number of non-zero elements in a sparse buffer.""" + ret = axes[0].nnz + for axis in axes[1:]: + if isinstance(axis, DenseFixedAxis): + ret = ret * axis.nnz + else: + ret = axis.nnz + return ret + + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`match_sparse_buffer` must be assigned to a single sparse buffer, " + "e.g. A = match_sparse_buffer(...)" + ) + + buffer_name: str = self.node.lhs[0].id.name + if not isinstance(param, tvm.tir.Var): + self.context.report_error( + "The source of match_sparse_buffer expected Var, but got" + str(type(param)), + self.node.rhs.params[0].span, + ) + + if param in self.context.func_params: + data = tvm.tir.decl_buffer(infer_nnz(axes), dtype, buffer_name + "_data", span=span) + buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name) + self.context.sp_struct.append(buffer) + self.context.sp_struct_params.append([param]) + self.context.update_symbol(buffer_name + "_data", data, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) + else: + self.context.report_error( + "Can not bind non-input param to sparse buffer", self.node.rhs.params[0].span + ) + + super().__init__(match_sparse_buffer, def_symbol=True) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 17f9aa3d9c60..8fa8d0339ad1 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -24,13 +24,14 @@ from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not -from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle +from .expr import Select, BufferLoad, SparseBufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle from .expr import Call, CallEffectKind, Let, IterVar, CommReducer, Any from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For, While from .stmt import ( BufferStore, BufferRealize, + SparseBufferStore, Store, ProducerStore, Allocate, @@ -40,12 +41,12 @@ from .stmt import ProducerRealize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list -from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize +from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize, SparseBlock from .function import PrimFunc, TensorIntrin from .op import call_packed, call_intrin, call_pure_extern, call_extern -from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace +from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace, lower_bound, upper_bound, atomic_add from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz from .op import sin, sinh, asin, asinh from .op import cos, cosh, acos, acosh @@ -65,3 +66,4 @@ from . import analysis from . import stmt_functor from . import usmp +from . import sparse diff --git a/python/tvm/tir/_ffi_api.py b/python/tvm/tir/_ffi_api.py index 1b60b8c81c6d..b6e939c6b533 100644 --- a/python/tvm/tir/_ffi_api.py +++ b/python/tvm/tir/_ffi_api.py @@ -19,3 +19,4 @@ tvm._ffi._init_api("tir", __name__) +tvm._ffi._init_api("tir.sparse", __name__) diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index 27cf5351a077..4ee6d9505ee9 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -1058,6 +1058,28 @@ def __init__(self, buffer, indices, span=None): ) +@tvm._ffi.register_object("tir.SparseBufferLoad") +class SparseBufferLoad(PrimExprWithOp): + """SparseBufferLoad node. + + Parameters + ---------- + buffer : SparseBuffer + The buffer to be loaded. + + indices : List[PrimExpr] + The indices location to be loaded. + + span : Optional[Span] + The location of this itervar in the source code. + """ + + def __init__(self, buffer, indices, span=None): + self.__init_handle_by_constructor__( + _ffi_api.SparseBufferLoad, buffer, indices, span # type: ignore + ) + + @tvm._ffi.register_object("tir.ProducerLoad") class ProducerLoad(PrimExprWithOp): """Producer load node. diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index de3ca5fa8d5b..4afe4fe055bf 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -971,6 +971,82 @@ def ldexp(x1, x2): return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore +def lower_bound(arr, val, l, r, span=None): + """Return the position to the first element in the arr[l:r] that is no less than val. + + Parameters + ---------- + arr : Var + Pointer to the 1D buffer to apply binary search on. + + val : PrimExpr + Value of the lower bound to search for in the buffer. + + l : PrimExpr + Start position to search for in the buffer. + + r : PrimExpr + End position to search for in the buffer. + + span : Optional[Span] + The location of this expression in the source code. + + Returns + ------- + PrimExpr + The index of element in arr[l:r] that is no less then given value. + """ + return _ffi_api.lower_bound(arr, val, l, r, span) # type: ignore + + +def upper_bound(arr, val, l, r, span=None): + """Return the position the first element in the arr that is greater than val. + + Parameters + ---------- + arr : Var + Pointer to the 1D buffer to apply binary search on. + + val : PrimExpr + Value of the upper bound to search for in the buffer. + + l : PrimExpr + Start position to search for in the buffer. + + r : PrimExpr + End position to search for in the buffer. + + span : Optional[Span] + The location of this expression in the source code. + + Returns + ------- + PrimExpr + The index of element in arr[l:r] that is no less then given value. + """ + return _ffi_api.upper_bound(arr, val, l, r, span) # type: ignore + + +def atomic_add(ptr, val, span=None): + """Perform an atomic add operation to ptr by the given val. + + Parameters + ---------- + ptr : Var + The pointer to the address we perform atomic add. + val : PrimExpr + The value to add. + span : Optional[Span] + The location of this expression in the source code. + + Returns + ------- + PrimExpr + The value on pointer before we perform the atomic add. + """ + return _ffi_api.atomic_add(ptr, val, span) # type: ignore + + def isnan(x, span=None): """Check if input value is Nan. diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 96fa21f30020..f78c0307bf19 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -21,7 +21,8 @@ from tvm.error import TVMError, register_error from tvm.ir import IRModule, PrimExpr from tvm.runtime import Object, String -from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc +from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, SparseBlock +from tvm.tir.sparse import SpIterVar from . import _ffi_api from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod @@ -56,12 +57,23 @@ def __init__(self) -> None: ) +@_register_object("tir.SparseBlockRV") +class SparseBlockRV(Object): + """A random variable that refers to a sparse block""" + + def __init__(self) -> None: + """Construct a new SparseBlockRV.""" + self.__init_handle_by_constructor__( + _ffi_api.SparseBlockRV # type: ignore # pylint: disable=no-member + ) + + # It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370 # This feature is not supported until python 3.10: # https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer -RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # pylint: disable=invalid-name +RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV, SparseBlockRV] # pylint: disable=invalid-name # Update to `Literal["detail", "fast", "none"]` once upgraded to python3.8 _ERROR_RENDER_LEVEL: Dict[str, int] = { @@ -227,7 +239,7 @@ def show(self, rand_var: RAND_VAR_TYPE) -> str: Parameters ---------- - rand_var : Union[ExprRV, BlockRV, LoopRV] + rand_var : Union[ExprRV, BlockRV, LoopRV, SparseBlockRV] The random variable to be evaluated Returns @@ -243,22 +255,23 @@ def show(self, rand_var: RAND_VAR_TYPE) -> str: def get( self, rand_var_or_sref: Union[RAND_VAR_TYPE, StmtSRef], - ) -> Optional[Union[int, Block, For]]: + ) -> Optional[Union[int, Block, For, SparseBlock]]: """Returns: - the corresponding Block that a BlockRV evaluates to; - the corresponding For that a LoopRV evaluates to; - the corresponding integer that a ExprRV evaluates to; + - the corresponding SparseBlock that a SparseBlockRV evaluates to; - the corresponding Block that a block sref points to; - the corresponding For that a loop sref points to; Parameters ---------- - rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, StmtSRef] + rand_var_or_sref : Union[ExprRV, BlockRV, LoopRV, SparseBlockRV, StmtSRef] The random variable / sref to be evaluated Returns ------- - result : Optional[Union[int, Block, For]] + result : Optional[Union[int, Block, For, SparseBlock]] The corresponding result """ if isinstance(rand_var_or_sref, StmtSRef): @@ -296,7 +309,7 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None: Parameters ---------- - rand_var : Union[BlockRV, LoopRV, ExprRV] + rand_var : Union[BlockRV, LoopRV, ExprRV, SparseBlockRV] The random variable to be removed """ return _ffi_api.ScheduleRemoveRV(self, rand_var) # type: ignore # pylint: disable=no-member @@ -2117,3 +2130,68 @@ def after_unannotate(a: T.handle, b: T.handle) -> None: def enter_postproc(self) -> None: """A no-op that marks the start of postprocessing phase of scheduling""" _ffi_api.ScheduleEnterPostproc(self) # type: ignore # pylint: disable=no-member + + ########## Schedule: SparseTIR schedules ########## + + def get_sparse_block( + self, + name: str, + func_name: str = "main", + ) -> SparseBlock: + """Retrieve a sparse block in a specific function with its name + + Parameters + ---------- + name : str + The name of the sparse block + func_name : str = "main" + The name of the function + + Returns + ------- + block : SparseBlockRV + The sparse block retrieved + IndexError is raised if 0 or multiple blocks exist with the specific name. + """ + return _ffi_api.ScheduleGetSparseBlock( # type: ignore # pylint: disable=no-member + self, + name, + func_name, + ) + + def get_sp_iters(self, block: SparseBlockRV) -> List[SpIterVar]: + """Retrieve the sparse iterators of a given sparse block + + Parameters + ---------- + block : SparseBlockRV + The block to be queried + + Returns + ------- + sp_iters : List[SpIterVar] + The sparse iterators of the input sparse block + """ + return _ffi_api.ScheduleGetSpIters( # type: ignore # pylint: disable=no-member + self, + block, + ) + + def sparse_reorder(self, block: SparseBlockRV, new_order: List[SpIterVar]) -> None: + """Reorder a list of sparse iterators. It requires the new order to not break the iterator + dependency. + + Parameters + ---------- + block : SparseBlockRV + The queried sparse block + + new_order : List[SpIterVar] + The The new order of the sparse iterators, whose length should equal to the number + of the input block's sparse iterators + """ + return _ffi_api.ScheduleSparseReorder( # type: ignore # pylint: disable=no-member + self, + block, + new_order, + ) diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py new file mode 100644 index 000000000000..df3a3cb82c0b --- /dev/null +++ b/python/tvm/tir/sparse.py @@ -0,0 +1,305 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""SparseTIR axes and SparseBuffer +""" +from typing import Dict, List, Optional + +import tvm._ffi +from tvm.ir import PrimExpr +from tvm.runtime import Object +from tvm.tir import Var + +from . import _ffi_api +from .buffer import Buffer + + +@tvm._ffi.register_object("tir.sparse.Axis") +class Axis(Object): + """Base class of all the sparse axes.""" + + @property + def name(self): + return _ffi_api.GetAxisName(self) + + @property + def length(self): + return _ffi_api.GetAxisLength(self) + + @property + def idtype(self): + return _ffi_api.GetAxisIndexType(self) + + @property + def nnz(self): + return _ffi_api.GetNNZ(self) + + @property + def parent(self): + return _ffi_api.GetParent(self) + + +@tvm._ffi.register_object("tir.sparse.DenseAxis") +class DenseAxis(Axis): + pass + + +@tvm._ffi.register_object("tir.sparse.SparseAxis") +class SparseAxis(Axis): + pass + + +@tvm._ffi.register_object("tir.sparse.DenseFixedAxis") +class DenseFixedAxis(DenseAxis): + """DenseFixedAxis node + + Parameters + ---------- + name : str + The name of the axis + + length : PrimExpr + The length of the axis + """ + + name: str + length: PrimExpr + + def __init__(self, name, length): + self.__init_handle_by_constructor__(_ffi_api.DenseFixedAxis, name, length) # type: ignore + + +@tvm._ffi.register_object("tir.sparse.DenseFromSparseAxis") +class DenseFromSparseAxis(DenseFixedAxis): + """DenseFromSparseAxis node + + Parameters + ---------- + base : Axis + The based sparse axis. + """ + + base: Axis + + def __init__(self, base): + self.__init_handle_by_constructor__(_ffi_api.DenseFromSparseAxis, base) # type: ignore + + +@tvm._ffi.register_object("tir.sparse.FusedAxis") +class FusedAxis(DenseFixedAxis): + """FusedAxis node + + Parameters + ---------- + group : List[Axis] + The axes group to be fused. + index : int + The index of current axis in the fused axes group. + """ + + group: List[Axis] + index: int + + def __init__(self, group, index): + self.__init_handle_by_constructor__(_ffi_api.FusedAxis, group, index) # type: ignore + + +@tvm._ffi.register_object("tir.sparse.DenseVariableAxis") +class DenseVariableAxis(DenseAxis): + """DenseVariableAxis node + + Parameters + ---------- + name : str + The name of the axis + + parent : Axis + The parent axis + + length : PrimExpr + The length of the axis + + indptr : Buffer + The indptr buffer of the axis + """ + + name: str + parent: Axis + length: PrimExpr + nnz: PrimExpr + indptr: Buffer + + def __init__(self, name, parent, length, nnz, indptr): + self.__init_handle_by_constructor__( + _ffi_api.DenseVariableAxis, name, parent, length, nnz, indptr # type: ignore + ) + + +@tvm._ffi.register_object("tir.sparse.AttachedAxis") +class AttachedAxis(DenseVariableAxis): + """AttachedAxis node + + Parameters + ---------- + name : str + The name of the axis. + parent : Axis + The axis to attach to. + orig : Axis + The axis to be attached. + nnz : PrimExpr + The number of nonzeros of the returned axis. + indptr : PrimExpr + The new indptr array of the the returned axis. + """ + + name : str + parent : Axis + orig : Axis + nnz : PrimExpr + indptr : PrimExpr + + def __init__(self, name, parent, orig, nnz, indptr): + self.__init_handle_by_constructor__( + _ffi_api.AttachedAxis, name, parent, orig, nnz, indptr + ) + + +@tvm._ffi.register_object("tir.sparse.SparseFixedAxis") +class SparseFixedAxis(DenseAxis): + """SparseFixedAxis node + + Parameters + ---------- + name : str + The name of the axis + + parent : Axis + The parent axis + + length : PrimExpr + The length of the axis + + indices : Buffer + The indices buffer of the axis + + nnz_cols : PrimExpr + The fixed number of non-zero elements along the axis + """ + + name: str + parent: Axis + length: PrimExpr + indices: Buffer + nnz_cols: PrimExpr + + def __init__(self, name, parent, length, indices, nnz_cols): + self.__init_handle_by_constructor__( + _ffi_api.SparseFixedAxis, name, parent, length, indices, nnz_cols # type: ignore + ) + + +@tvm._ffi.register_object("tir.sparse.SparseVariableAxis") +class SparseVariableAxis(DenseAxis): + """SparseVariableAxis node + + Parameters + ---------- + name : str + The name of the axis + + parent : Axis + The parent axis + + length : PrimExpr + The length of the axis + + indptr : Buffer + The indptr buffer of the axis + + indices : Buffer + The indices buffer of the axis + """ + + name: str + parent: Axis + length: PrimExpr + indptr: Buffer + indices: Buffer + + def __init__(self, name, parent, length, indptr, indices): + self.__init_handle_by_constructor__( + _ffi_api.SparseVariableAxis, name, parent, length, indptr, indices # type: ignore + ) + + +@tvm._ffi.register_object("tir.sparse.SparseBuffer") +class SparseBuffer(Object): + """SparseBuffer node + + Parameters + ---------- + axes : List[Axis] + The axes of the sparse buffer + + data : Buffer + The data of the sparse buffer + + name : str + The name of the sparse buffer + """ + + axes: List[Axis] + data: Buffer + name: str + + def __init__(self, axes, data, name): + self.__init_handle_by_constructor__(_ffi_api.SparseBuffer, axes, data, name) # type: ignore + + +@tvm._ffi.register_object("tir.sparse.SpIterVar") +class SpIterVar(Object): + """IterVar in SparseTIR + + Parameters + ---------- + var : Var + The var of the SpIterVar + + max_extent : PrimExpr + The maximum extent of the SpIterVar + + is_reduction : bool + Whether the SpIterVar is a reduction iterator + + axis : Axis + The axis over which the SpIterVar iterates + """ + + var: Var + max_extent: PrimExpr + is_reduction: bool + axis: Axis + + DenseFixed = 0 + DenseVariable = 1 + SparseFixed = 2 + SparseVariable = 3 + + def __init__(self, var, max_extent, is_reduction, axis): + self.__init_handle_by_constructor__( + _ffi_api.SpIterVar, var, max_extent, is_reduction, axis # type: ignore + ) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 39831459f344..4c54b556d254 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -35,7 +35,8 @@ from . import _ffi_api from .buffer import Buffer -from .expr import IterVar +from .expr import Var, IterVar +from .sparse import SpIterVar, SparseBuffer class Stmt(Object): @@ -244,6 +245,31 @@ def __init__(self, buffer, value, indices, span=None): ) +@tvm._ffi.register_object("tir.SparseBufferStore") +class SparseBufferStore(Stmt): + """SparseBufferStore node. + + Parameters + ---------- + buffer : SparseBuffer + The sparse buffer to be accessed. + + value : PrimExpr + The value to be stored. + + indices : List[PrimExpr] + The indices location to be stored. + + span : Optional[Span] + The location of this itervar in the source code. + """ + + def __init__(self, buffer, value, indices, span=None): + self.__init_handle_by_constructor__( + _ffi_api.SparseBufferStore, buffer, value, indices, span # type: ignore + ) + + @tvm._ffi.register_object("tir.BufferRealize") class BufferRealize(Stmt): """Buffer realize node. @@ -648,6 +674,67 @@ def __init__( ) # type: ignore +@tvm._ffi.register_object("tir.SparseBlock") +class SparseBlock(Stmt): + """SparseBlock node. + + Parameters + ---------- + sp_iter_vars : List[SpIterVar] + The sparse iteration variables of the block. + + sp_struct : List[Object] + The sparse data structures + + sp_struct_params : List[List[Var]] + The function parameters that corresponding to each sparse data structures + + sp_struct2param_map : Mapping[Object, List[Var]] + The mapping from sparse data structures to the PrimFunc parameters. + + name : str + The name of the block. + + body : Stmt + The body of the block. + + init : Optional[Stmt] + The init statement of the block. + + span : Optional[Span] + The location of this block in the source code. + """ + + sp_iter_vars: List[SpIterVar] + sp_struct: List[Object] + sp_struct2param_map: Mapping[Object, List[Var]] + name: str + body: Stmt + init: Optional[Stmt] + span: Optional[Span] + + def __init__( + self, + sp_iter_vars: List[SpIterVar], + sp_struct: List[Object], + sp_struct_params: List[List[Var]], + name: str, + body: Stmt, + init: Optional[Stmt] = None, + span: Optional[Span] = None, + ): + self.__init_handle_by_constructor__( + _ffi_api.SparseBlock, # type: ignore + sp_iter_vars, + sp_struct, + sp_struct_params, + name, + body, + init, + span, + ) # type: ignore + + @tvm._ffi.register_object("tir.BlockRealize") class BlockRealize(Stmt): """BlockRealize node. diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 74e1f70121ef..708a7cd9b9af 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -764,7 +764,6 @@ def ConvertForLoopsToSerial(): def InjectSoftwarePipeline(): """Transform annotated loops into pipelined one that parallelize producers and consumers - Returns ------- fpass : tvm.transform.Pass @@ -793,3 +792,14 @@ def RenomalizeSplitPattern(): The result pass """ return _ffi_api.RenormalizeSplitPattern() # type: ignore + + +def LowerSparseTIR(): + """Lower SparseTIR to TIR + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerSparseTIR() # type: ignore diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index e1ccd2f5e428..54bbb4ee0719 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -173,6 +173,10 @@ class TVMScriptPrinter : public StmtFunctor, std::unordered_map memo_buf_; /*! \brief Map from Buffer to Declaration Doc */ std::unordered_map memo_buf_decl_; + /*! \brief Map from SparseBuffer to Doc */ + std::unordered_map memo_sp_buf_; + /*! \brief Map from Axis in SparseTIR to Doc */ + std::unordered_map memo_sp_axis_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; /*! \brief number of children of current node's parent */ @@ -228,6 +232,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) override; + Doc VisitExpr_(const SparseBufferLoadNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) override; Doc VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) override; @@ -242,6 +247,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitStmt_(const AssertStmtNode* op) override; Doc VisitStmt_(const StoreNode* op) override; Doc VisitStmt_(const BufferStoreNode* op) override; + Doc VisitStmt_(const SparseBufferStoreNode* op) override; Doc VisitStmt_(const BufferRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; Doc VisitStmt_(const AllocateConstNode* op) override; @@ -252,6 +258,7 @@ class TVMScriptPrinter : public StmtFunctor, Doc VisitStmt_(const PrefetchNode* op) override; Doc VisitStmt_(const EvaluateNode* op) override; Doc VisitStmt_(const BlockRealizeNode* op) override; + Doc VisitStmt_(const SparseBlockNode* op) override; Doc VisitStmtDefault_(const Object* op) override; Doc VisitType_(const PrimTypeNode* node) override; @@ -266,6 +273,8 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintArray(const ArrayNode* op); Doc PrintBuffer(const BufferNode* op); Doc PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body); + Doc PrintSparseBuffer(const SparseBufferNode* op); + Doc PrintSpAxis(const AxisNode* op); Doc AllocBufferDeclaration(const Buffer& buf); Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); Doc PrintBlockVarRemaps(); @@ -274,6 +283,9 @@ class TVMScriptPrinter : public StmtFunctor, Doc PrintExpandedArray(const ArrayNode* op); Doc PrintBlockBody(const BlockNode* op); virtual Doc PrintBlockName(const BlockNode* block_op); + Doc PrintSparseBlockName(const SparseBlockNode* op); + Doc PrintSparseStructDefinitions(const SparseBlockNode* sp_block); + Doc PrintBufferRegion(const BufferRegionNode* op); Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); Doc PrintCommReducer(const CommReducerNode* op); @@ -284,6 +296,8 @@ class TVMScriptPrinter : public StmtFunctor, Doc GetUniqueName(std::string prefix); Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); + Doc AllocSparseBuf(const SparseBuffer& buffer); + Doc AllocAxis(const Axis& axis); void TryDeallocVar(const Var& var); bool ContainsOptionalInfo(const Stmt& stmt); /*! @@ -520,6 +534,44 @@ Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) { return val; } +Doc TVMScriptPrinter::AllocSparseBuf(const SparseBuffer& buffer) { + const auto& it = memo_sp_buf_.find(buffer); + if (it != memo_sp_buf_.end()) { + return it->second; + } + std::string name = buffer->name; + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "buf_" + name; + } + Doc val = GetUniqueName(name); + memo_sp_buf_[buffer] = val; + return val; +} + +Doc TVMScriptPrinter::AllocAxis(const Axis& axis) { + const auto& it = memo_sp_axis_.find(axis); + if (it != memo_sp_axis_.end()) { + return it->second; + } + Doc val; + if (const DenseFromSparseAxisNode* dfs_axis = axis.as()) { + // DenseFromSparseAxis is a temporally defined axis. + val = Doc::Text(tir_prefix_ + ".dense(" + dfs_axis->base->name + ")"); + } else if (axis.as()) { + // FusedAxis is also a temporally defined axis. + CHECK(false) << "Cannot allocate fused axis"; + } else { + std::string name = axis->name; + if (name.length() == 0 || !std::isalnum(name[0])) { + name = "axis_" + name; + } + val = GetUniqueName(name); + } + + memo_sp_axis_[axis] = val; + return val; +} + /*! * \brief Check if any optional information exists in annotate_ for * a given Stmt. @@ -681,6 +733,10 @@ Doc TVMScriptPrinter::Print(const ObjectRef& node) { return PrintArray(node.as()); } else if (node->IsInstance()) { return PrintBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintSparseBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintSpAxis(node.as()); } else if (node->IsInstance()) { return PrintString(node.as()); } else if (node->IsInstance()) { @@ -833,6 +889,13 @@ Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_p return doc; } +Doc TVMScriptPrinter::VisitExpr_(const SparseBufferLoadNode* op, ExprPrecedence* out_precedence) { + *out_precedence = ExprPrecedence::kIdentity; + Doc doc; + doc << Print(op->buffer) << Print(op->indices); + return doc; +} + Doc TVMScriptPrinter::VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) { *out_precedence = ExprPrecedence::kIdentity; Doc doc; @@ -1194,6 +1257,12 @@ Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } +Doc TVMScriptPrinter::VisitStmt_(const SparseBufferStoreNode* op) { + Doc doc; + doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); + return doc; +} + /*! Helper functions for block printing. */ Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { Doc doc; @@ -1374,6 +1443,140 @@ Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { return doc; } +Doc TVMScriptPrinter::PrintSparseBlockName(const SparseBlockNode* op) { + Doc doc; + doc << "with " << tir_prefix_ << ".iter(["; + + int n_iter = static_cast(op->sp_iter_vars.size()); + + std::string iter_types = ""; + std::vector sp_iter_docs; + std::vector sp_iter_name_docs; + iter_types.reserve(n_iter); + sp_iter_docs.reserve(n_iter); + sp_iter_name_docs.reserve(n_iter); + + for (int i = 0; i < n_iter; ++i) { + const SpIterVar& sp_iter = op->sp_iter_vars[i]; + const Axis& axis = sp_iter->axis; + Doc iter_doc; + + std::string axis_repr = sp_iter->axis->name; + if (const DenseFromSparseAxisNode* dfs_axis = axis.as()) { + iter_doc << tir_prefix_ << ".dense(" << dfs_axis->base->name << ")"; + } else if (const FusedAxisNode* fused_axis = axis.as()) { + std::string orig_axis_name = fused_axis->group[fused_axis->index]->name; + if (fused_axis->index == 0) { + iter_doc << tir_prefix_ << ".fuse(" << orig_axis_name; + } else if (fused_axis->index == int(fused_axis->group.size() - 1)) { + iter_doc << orig_axis_name << ")"; + } else { + iter_doc << orig_axis_name; + } + } else { + iter_doc << axis->name; + } + + var_not_in_headers_.insert(sp_iter->var.get()); + sp_iter_docs.push_back(iter_doc); + sp_iter_name_docs.push_back(Print(sp_iter->var)); + iter_types += sp_iter->is_reduction ? "R" : "S"; + } + + doc << PrintSep(sp_iter_docs, Doc::Text(", ")) << "], " << Doc::StrLiteral(iter_types) << ", " + << Doc::StrLiteral(op->name) << ") as [" << PrintSep(sp_iter_name_docs, Doc::Text(", ")) + << "]:"; + + return doc; +} + +Doc TVMScriptPrinter::VisitStmt_(const SparseBlockNode* op) { + Doc doc = PrintOptionalInfo(GetRef(op)); + doc << PrintSparseBlockName(op); + + Doc body; + if (op->init.defined()) { + Doc init; + init << "with " << tir_prefix_ << ".init():"; + init << Doc::Indent(4, Doc::NewLine() << PrintBody(op->init.value())); + body << init << Doc::NewLine(); + } + body << PrintBody(op->body); + doc << Doc::Indent(4, Doc::NewLine() << body); + + for (const SpIterVar& sp_iter : op->sp_iter_vars) { + TryDeallocVar(sp_iter->var); + } + return doc; +} + +Doc TVMScriptPrinter::PrintSparseStructDefinitions(const SparseBlockNode* sp_block) { + std::vector axis_docs; + std::vector sp_buf_docs; + + for (const ObjectRef& obj : sp_block->sp_structs) { + Array params = sp_block->sp_struct_param_map.Get(obj).value(); + + Doc doc; + doc << Print(obj) << " = " << tir_prefix_ << "."; + + if (const auto* sp_buffer = obj.as()) { + ICHECK_EQ(params.size(), 1); + Doc axes_doc; + if (sp_buffer->axes.size() != 1) { + std::vector axes_docs; + axes_docs.reserve(sp_buffer->axes.size()); + for (const Axis& axis : sp_buffer->axes) { + axes_docs.push_back(Print(axis)); + } + axes_doc << PrintSep(axes_docs, Doc::Text(", ")); + } else { + axes_doc << Print(sp_buffer->axes[0]) << ","; + } + + doc << "match_sparse_buffer(" << Print(params[0]) << ", (" << axes_doc << "), " + << PrintDType(sp_buffer->data->dtype) << ")"; + sp_buf_docs.push_back(doc); + continue; + } + + if (const auto* df_axis = obj.as()) { + ICHECK_EQ(params.size(), 0); + doc << "dense_fixed(" << Print(df_axis->length) << ")"; + } else if (const auto* dv_axis = obj.as()) { + if (const auto* attached_axis = obj.as()) { + ICHECK_EQ(params.size(), 1); + doc << "attach_axis(" << attached_axis->parent_->name << ", " << attached_axis->orig_->name + << ", " << Print(attached_axis->GetNNZ()) << ", " << Print(params[0]) << ", " + << PrintDType(attached_axis->indptr->dtype) << ")"; + } else { + ICHECK_EQ(params.size(), 1); + doc << "dense_variable(" << dv_axis->parent_->name << ", (" << Print(dv_axis->length) + << ", " << Print(dv_axis->GetNNZ()) << "), " << Print(params[0]) << ", " + << PrintDType(dv_axis->indptr->dtype) << ")"; + } + } else if (const auto* sf_axis = obj.as()) { + ICHECK_EQ(params.size(), 1); + doc << "sparse_fixed(" << sf_axis->parent_->name << ", (" << Print(sf_axis->length) << ", " + << Print(sf_axis->nnz_cols) << "), " << Print(params[0]) << ", " + << PrintDType(sf_axis->indices->dtype) << ")"; + } else if (const auto* sv_axis = obj.as()) { + ICHECK_EQ(params.size(), 2); + doc << "sparse_variable(" << sv_axis->parent_->name << ", (" << Print(sv_axis->length) << ", " + << Print(sv_axis->GetNNZ()) << "), (" << Print(params[0]) << ", " << Print(params[1]) + << "), " << PrintDType(sv_axis->indptr->dtype) << ")"; + } else { + ICHECK(false) << "Cannot reach here"; + } + axis_docs.push_back(doc); + } + + Doc res; + res << PrintSep(axis_docs, Doc::NewLine()) << Doc::NewLine() + << PrintSep(sp_buf_docs, Doc::NewLine()) << Doc::NewLine(); + return res; +} + Doc TVMScriptPrinter::PrintBody(const Stmt& body) { int memo_num_child, memo_current_num; std::swap(memo_num_child, num_child_); @@ -1428,6 +1631,7 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { memo_var_.clear(); memo_buf_.clear(); memo_buf_decl_.clear(); + memo_sp_buf_.clear(); var_not_in_headers_.clear(); buf_not_in_headers_.clear(); // print signature @@ -1468,6 +1672,10 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { body << Print((*it).first) << ", " << memo_buf_decl_[buf]; body << ")" << Doc::NewLine(); } + // print sparse data structure definitions + if (const auto* sp_block = op->body.as()) { + body << PrintSparseStructDefinitions(sp_block); + } // print body body << "# body" << Doc::NewLine(); if (op->body->IsInstance() && @@ -1600,6 +1808,16 @@ Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(Var buffer_var, Stmt body return decls; } +Doc TVMScriptPrinter::PrintSparseBuffer(const SparseBufferNode* op) { + const SparseBuffer& buffer = GetRef(op); + return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocSparseBuf(buffer); +} + +Doc TVMScriptPrinter::PrintSpAxis(const AxisNode* op) { + const Axis& axis = GetRef(op); + return meta_.InMeta(axis) ? meta_.GetMetaNode(axis) : AllocAxis(axis); +} + Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) { Doc doc; if (op->region.size() == 0) { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 984f8a13351e..415a2f46ea32 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -33,6 +33,7 @@ #include #include "literal/cuda_half_t.h" +#include "literal/cuda_binary_search.h" #include "ptx_mma.h" namespace tvm { @@ -133,6 +134,10 @@ std::string CodeGenCUDA::Finish() { decl_stream << "#include \n"; } + if (need_binary_search_) { + decl_stream << _cuda_binary_search_def; + } + decl_stream << "\n#ifdef _WIN32\n"; decl_stream << " using uint = unsigned int;\n"; decl_stream << " using uchar = unsigned char;\n"; @@ -756,6 +761,37 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, saturate); this->stream << asm_code; + } else if (op->op.same_as(builtin::tvm_lower_bound())) { + need_binary_search_ = true; + os << "__lower_bound("; + ICHECK_EQ(op->args.size(), 4U); + this->PrintExpr(op->args[0], os); + os << ", "; + this->PrintExpr(op->args[1], os); + os << ", "; + this->PrintExpr(op->args[2], os); + os << ", "; + this->PrintExpr(op->args[3], os); + os << ")"; + } else if (op->op.same_as(builtin::tvm_upper_bound())) { + need_binary_search_ = true; + os << "__upper_bound("; + ICHECK_EQ(op->args.size(), 4U); + this->PrintExpr(op->args[0], os); + os << ", "; + this->PrintExpr(op->args[1], os); + os << ", "; + this->PrintExpr(op->args[2], os); + os << ", "; + this->PrintExpr(op->args[3], os); + os << ")"; + } else if (op->op.same_as(builtin::tvm_atomic_add())) { + os << "atomicAdd("; + ICHECK_EQ(op->args.size(), 2U); + this->PrintExpr(op->args[0], os); + os << ", "; + this->PrintExpr(op->args[1], os); + os << ")"; } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 385b7343c8fd..18ad850e7cd6 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -99,6 +99,8 @@ class CodeGenCUDA final : public CodeGenC { bool need_math_constants_h_{false}; // whether need mma.h bool need_mma_h_{false}; + // whether need binary search + bool need_binary_search_{false}; // Op attribute map OpAttrMap op_need_warp_shuffle_ = Op::GetAttrMap("cuda.need_warp_shuffle"); diff --git a/src/target/source/literal/cuda_binary_search.h b/src/target/source/literal/cuda_binary_search.h new file mode 100644 index 000000000000..2c7a2b6a770f --- /dev/null +++ b/src/target/source/literal/cuda_binary_search.h @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file cuda_binary_search.h + * \brief Binary search function definition for cuda codegen. + */ +#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_ +#define TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_ + +static constexpr const char* _cuda_binary_search_def = R"( +template +__forceinline__ __device__ int __lower_bound( + const DType* __restrict__ arr, + DType val, + int l, + int r) { + int low = l - 1, high = r; + /* loop invariant: low < mid < high, arr[low] < val, arr[high] >= val */ + while (low + 1 < high) { + int mid = (low + high) >> 1; + if (arr[mid] < val) { + low = mid; + } else { + high = mid; + } + } + // high = low + 1, arr[low] < val, arr[high] >= val + return high; +} + +template +__forceinline__ __device__ int __upper_bound( + const DType* __restrict__ arr, + DType val, + int l, + int r) { + int low = l - 1, high = r; + /* loop invariant: low < mid < high, arr[low] < val, arr[high] > val */ + while (low + 1 < high) { + int mid = (low + high) >> 1; + if (arr[mid] > val) { + high = mid; + } else { + low = mid; + } + } + // high = low + 1, arr[low] <= val, arr[high] > val + return high; +} +)"; + +#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_BINARY_SEARCH_H_ diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index fbbd4a9522eb..4af178d43848 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -1085,6 +1085,36 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "]"; }); +// SparseBufferLoad +SparseBufferLoad::SparseBufferLoad(SparseBuffer buffer, Array indices, Span span) { + ObjectPtr node = make_object(); + node->dtype = buffer->data->dtype; + node->buffer = std::move(buffer); + node->indices = std::move(indices); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.SparseBufferLoad") + .set_body_typed([](SparseBuffer buffer, Array indices, Span span) { + return SparseBufferLoad(buffer, indices, span); + }); + +TVM_REGISTER_NODE_TYPE(SparseBufferLoadNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) { + p->stream << ", "; + } + } + p->stream << "]"; + }); + // ProducerLoad ProducerLoad::ProducerLoad(DataProducer producer, Array indices, Span span) { ObjectPtr node = make_object(); diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 4c5ea5bfd2d0..b7e0665cf9fd 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -43,6 +43,10 @@ void ExprVisitor::VisitExpr_(const BufferLoadNode* op) { VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } +void ExprVisitor::VisitExpr_(const SparseBufferLoadNode* op) { + VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); +} + void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) { VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } @@ -146,6 +150,16 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) { } } +PrimExpr ExprMutator::VisitExpr_(const SparseBufferLoadNode* op) { + auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; + Array indices = MutateArray(op->indices, fmutate); + if (indices.same_as(op->indices)) { + return GetRef(op); + } else { + return SparseBufferLoad(op->buffer, indices); + } +}; + PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) { auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); }; Array indices = MutateArray(op->indices, fmutate); diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc new file mode 100644 index 000000000000..d68acceb473c --- /dev/null +++ b/src/tir/ir/sparse.cc @@ -0,0 +1,461 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file sparse.cc + * \brief buffers and formats in sparse tir. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +/******** Attributes of sparse axis. ********/ + +TVM_REGISTER_GLOBAL("tir.sparse.GetAxisName").set_body_typed([](Axis axis) { + return axis->GetName(); +}); + +TVM_REGISTER_GLOBAL("tir.sparse.GetAxisLength").set_body_typed([](Axis axis) { + return axis->GetLength(); +}); + +TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis) { + return DLDataType2String(axis->GetIndexType()); +}); + +TVM_REGISTER_GLOBAL("tir.sparse.GetNNZ").set_body_typed([](Axis axis) { return axis->GetNNZ(); }); + +TVM_REGISTER_GLOBAL("tir.sparse.GetParent").set_body_typed([](Axis axis) { return axis->GetParentAxis(); }); + +/******** AxisNode ********/ + +std::tuple AxisNode::GetOffsetExtent(SparseCtx* ctx) const { + auto prev = ctx->GetPrevAxis(GetRef(this)); + if (prev.defined()) { + Axis prev_axis = prev.value(); + PrimExpr lb = Aggregate(ctx, 0); + PrimExpr orig_prev_coordinate = ctx->GetCoordinate(prev_axis), + orig_prev_offset = ctx->GetOffset(prev_axis); + ctx->SetCoordinate(prev_axis, orig_prev_coordinate + 1); + ctx->SetOffset(prev_axis, orig_prev_offset + 1); + PrimExpr ub = Aggregate(ctx, 0); + ctx->SetCoordinate(prev_axis, orig_prev_coordinate); + ctx->SetOffset(prev_axis, orig_prev_offset); + return {lb, ub}; + } else { + return {Integer(0), GetNNZ()}; + } +}; + +/******** DenseFixedAxis ********/ + +/*! \brief Default constructor of DenseFixedAxis */ +DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->length = std::move(length); + data_ = std::move(node); +} + +PrimExpr DenseFixedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const { + auto try_prev = ctx->GetPrevAxis(GetRef(this)); + if (try_prev.defined()) { + Axis prev_axis = try_prev.value(); + PrimExpr prev_offset = ctx->GetOffset(prev_axis); + return prev_offset * length + std::move(index); + } else { + return index; + } +} + +PrimExpr DenseFixedAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const { + return coordinate; +} + +PrimExpr DenseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const { + return index; +} + +TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) { + return DenseFixedAxis(std::move(name), std::move(length)); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "dense_fixed(" << op->name << ", " << op->length << ")"; + }); + +/******** DenseFromSparseAxis ********/ + +/*! \brief Default constructor of DenseFromSparseAxis */ +DenseFromSparseAxis::DenseFromSparseAxis(SparseAxis base) { + ObjectPtr node = make_object(); + node->name = base->name + "_dense"; + node->length = base->length; + node->base = std::move(base); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(DenseFromSparseAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.DenseFromSparseAxis").set_body_typed([](SparseAxis base) { + return DenseFromSparseAxis(std::move(base)); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "dense_from_sparse(" << op->base->name << ")"; + }); + +/******** FusedAxis ********/ + +/*! \brief Default constructor of FusedAxis */ +FusedAxis::FusedAxis(Array group, int index) { + CHECK(index < int(group.size())) << "Index " << index << "exceeds the size of fused axes group."; + + // TODO(zihao): check whether it valid to fuse axes in the group. + + ObjectPtr node = make_object(); + std::string fused_name = group[0]->name; + for (size_t i = 1; i < group.size(); ++i) { + fused_name += group[i]->name; + } + node->name = "fused_" + fused_name + "_" + group[index]->name; + node->length = group[index]->GetNNZ(); + node->group = std::move(group); + node->index = index; + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(FusedAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.FusedAxis").set_body_typed([](Array group, int index) { + return FusedAxis(std::move(group), index); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "fused("; + bool first = true; + for (auto&& orig_axis : op->group) { + if (first) { + first = false; + } else { + p->stream << ", "; + } + p->stream << orig_axis->name; + } + p->stream << ")"; + }); + +/******** DenseVariableAxis ********/ + +/*! \brief Default constructor of DenseVariableAxis */ +DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr nnz, + Buffer indptr) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->parent_ = std::move(parent); + node->length = std::move(length); + node->nnz_ = std::move(nnz); + node->indptr = std::move(indptr); + data_ = std::move(node); +} + +PrimExpr DenseVariableAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const { + Axis prev_axis = ctx->GetPrevAxis(GetRef(this)).value(); + PrimExpr prev_offset = ctx->GetOffset(prev_axis); + return BufferLoad(indptr, {std::move(prev_offset)}) + std::move(index); +} + +PrimExpr DenseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const { + return coordinate; +} + +PrimExpr DenseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const { + return index; +} + +TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis") + .set_body_typed([](String name, Axis parent, PrimExpr length, PrimExpr nnz, Buffer indptr) { + return DenseVariableAxis(std::move(name), std::move(parent), std::move(length), + std::move(nnz), std::move(indptr)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name + << ")"; + }); + +/******** AttachedAxis ********/ +/*! \brief Default constructor of AttachedAxis */ +AttachedAxis::AttachedAxis(String name, Axis parent, DenseVariableAxis orig, PrimExpr nnz, + Buffer indptr) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->parent_ = std::move(parent); + node->orig_ = std::move(orig); + node->length = node->orig_->length; + node->nnz_ = std::move(nnz); + node->indptr = std::move(indptr); + data_ = std::move(node); +} + +PrimExpr AttachedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const { + PrimExpr root_offset = ctx->GetOffset(orig_->parent_); + PrimExpr accum_offset = BufferLoad(indptr, {root_offset}); + Array collect_axes; + Array collect_coordinates; + Array strides; + Axis axis; + PrimExpr stride = Integer(1); + for (axis = GetRef(this); axis->kind() == AxisKind::kDenseVariable; + axis = ctx->GetPrevAxis(axis).value()) { + DenseVariableAxis dv_axis = Downcast(axis); + collect_axes.push_back(dv_axis); + collect_coordinates.push_back(ctx->GetCoordinate(axis)); + Buffer indptr; + if (auto att_axis = dv_axis.as()) { + indptr = att_axis->orig_->indptr; + } else { + indptr = dv_axis->indptr; + } + strides.push_back(stride); + stride = stride * (BufferLoad(indptr, {root_offset + 1}) - BufferLoad(indptr, {root_offset})); + } + ICHECK(axis == orig_->parent_) << "Root axis mismatch."; + PrimExpr length = Integer(0); + for (int i = collect_axes.size() - 1; i >= 0; --i) { + DenseVariableAxis axis = std::move(collect_axes[i]); + PrimExpr coordinate = std::move(collect_coordinates[i]); + PrimExpr stride = std::move(strides[i]); + accum_offset = accum_offset + coordinate * stride; + } + return accum_offset; +} + +TVM_REGISTER_NODE_TYPE(AttachedAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.AttachedAxis") + .set_body_typed([](String name, Axis parent, DenseVariableAxis orig, PrimExpr nnz, + Buffer indptr) { + return AttachedAxis(std::move(name), std::move(parent), std::move(orig), std::move(nnz), + std::move(indptr)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "attached_axis(" << op->name << ", " << op->length << ", " << op->indptr->name + << ")"; + }); + +/******** SparseFixedAxis ********/ + +/*! \brief Default constructor of SparseFixedAxis */ +SparseFixedAxis::SparseFixedAxis(String name, Axis parent, PrimExpr length, Buffer indices, + PrimExpr nnz_cols) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->parent_ = std::move(parent); + node->length = std::move(length); + node->indices = std::move(indices); + node->nnz_cols = std::move(nnz_cols); + data_ = std::move(node); +} + +PrimExpr SparseFixedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const { + Axis prev_axis = ctx->GetPrevAxis(GetRef(this)).value(); + PrimExpr prev_offset = ctx->GetOffset(prev_axis); + return std::move(prev_offset) * nnz_cols + std::move(index); +} + +PrimExpr SparseFixedAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const { + PrimExpr lb, ub; + std::tie(lb, ub) = GetOffsetExtent(ctx); + return lower_bound(indices->data, coordinate, lb, ub) - lb; +} + +PrimExpr SparseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const { + return BufferLoad(indices, {offset}); +} + +TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis") + .set_body_typed([](String name, Axis parent, PrimExpr length, Buffer indices, + PrimExpr nnz_cols) { + return SparseFixedAxis(std::move(name), std::move(parent), std::move(length), + std::move(indices), std::move(nnz_cols)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "sparse_fixed(" << op->name << ", " << op->parent_->name << ", " << op->length + << ", " << op->nnz_cols << ", " << op->indices->name << ")"; + }); + +/******** SparseVariableAxis ********/ + +/*! \brief Default constructor of SparseVariableAxis */ +SparseVariableAxis::SparseVariableAxis(String name, Axis parent, PrimExpr length, Buffer indptr, + Buffer indices) { + ObjectPtr node = make_object(); + node->name = std::move(name); + node->parent_ = std::move(parent); + node->length = std::move(length); + node->indptr = std::move(indptr); + node->indices = std::move(indices); + data_ = std::move(node); +} + +PrimExpr SparseVariableAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const { + Axis prev_axis = ctx->GetPrevAxis(GetRef(this)).value(); + PrimExpr prev_offset = ctx->GetOffset(prev_axis); + return BufferLoad(indptr, {std::move(prev_offset)}) + std::move(index); +} + +PrimExpr SparseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const { + PrimExpr lb, ub; + std::tie(lb, ub) = GetOffsetExtent(ctx); + return lower_bound(indices->data, coordinate, lb, ub) - lb; +} + +PrimExpr SparseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const { + return BufferLoad(indices, {offset}); +} + +TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis") + .set_body_typed([](String name, Axis parent, PrimExpr length, Buffer indptr, Buffer indices) { + return SparseVariableAxis(std::move(name), std::move(parent), std::move(length), + std::move(indptr), std::move(indices)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "sparse_variable(" << op->name << ", " << op->length << ", " << op->indptr->name + << ", " << op->indices->name << ")"; + }); + +/******** SparseBuffer ********/ + +/*! \brief Default constructor of SparseBuffer */ +SparseBuffer::SparseBuffer(Array axes, Buffer data, String name) { + ObjectPtr node = make_object(); + CHECK_GT(static_cast(axes.size()), 0) + << "ValueError: A SparseBuffer should have at least one dimension"; + node->axes = std::move(axes); + node->data = std::move(data); + node->name = std::move(name); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(SparseBufferNode); + +TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer") + .set_body_typed([](Array axes, Buffer data, String name) { + return SparseBuffer(std::move(axes), std::move(data), std::move(name)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "sparse_buffer(" << op->name << ", ["; + for (int i = 0, n = static_cast(op->axes.size()); i < n; ++i) { + const Axis& axis = op->axes[i]; + p->stream << axis; + if (i < n - 1) { + p->stream << ", "; + } + } + p->stream << "], " << op->data << ")"; + }); + +/******** AxisKind ********/ + +/*! \brief Printer function of Axiskind. */ +std::ostream& operator<<(std::ostream& out, AxisKind type) { + switch (type) { + case AxisKind::kDenseFixed: + out << "dense-fixed"; + break; + case AxisKind::kDenseVariable: + out << "dense-variable"; + break; + case AxisKind::kSparseFixed: + out << "sparse-fixed"; + break; + case AxisKind::kSparseVariable: + out << "sparse-variable"; + break; + default: + LOG(FATAL) << "Cannot reach here"; + } + return out; +} + +/******** SpIterVar ********/ + +/*! \brief Default constructor of SpIterVar. */ +SpIterVar::SpIterVar(Var var, PrimExpr max_extent, bool is_reduction, Axis axis) { + ObjectPtr node = make_object(); + + arith::Analyzer ana; + + node->var = Var(std::move(var)); + node->max_extent = std::move(max_extent); + node->is_reduction = is_reduction; + node->axis = std::move(axis); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(SpIterVarNode); + +TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar") + .set_body_typed([](Var var, PrimExpr max_extent, bool is_reduction, Axis axis) { + return SpIterVar(std::move(var), std::move(max_extent), is_reduction, std::move(axis)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "sp_iter_var(" << op->var->name_hint << ", " << op->max_extent << ", " + << (op->is_reduction ? "reduction" : "spatial") << ", " << op->axis->name << ")"; + }); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1269607fd334..f714c3748e4c 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -698,6 +698,39 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << '\n'; }); +// SparseBufferStore +SparseBufferStore::SparseBufferStore(SparseBuffer buffer, PrimExpr value, Array indices, + Span span) { + ObjectPtr node = make_object(); + node->buffer = std::move(buffer); + node->value = std::move(value); + node->indices = std::move(indices); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.SparseBufferStore") + .set_body_typed([](SparseBuffer buffer, PrimExpr value, Array indices, Span span) { + return SparseBufferStore(buffer, value, indices, span); + }); + +TVM_REGISTER_NODE_TYPE(SparseBufferStoreNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer->name << "["; + for (size_t i = 0; i < op->indices.size(); ++i) { + p->Print(op->indices[i]); + if (i < op->indices.size() - 1) p->stream << ", "; + } + p->stream << "]"; + p->stream << " = "; + p->Print(op->value); + p->stream << '\n'; + }); + // BufferRealize BufferRealize::BufferRealize(Buffer buffer, Array bounds, PrimExpr condition, Stmt body, Span span) { @@ -923,17 +956,21 @@ void PrintBlockSignature(const BlockNode* op, ReprPrinter* p) { } } -void PrintBlockBody(const BlockNode* op, ReprPrinter* p) { - // Print init - if (op->init.defined()) { +void PrintInitStmt(const Optional& init, ReprPrinter* p) { + if (init.defined()) { p->PrintIndent(); p->stream << "with init() {\n"; p->indent += 2; - p->Print(op->init.value()); + p->Print(init.value()); p->indent -= 2; p->PrintIndent(); p->stream << "}\n"; } +} + +void PrintBlockBody(const BlockNode* op, ReprPrinter* p) { + // Print init + PrintInitStmt(op->init, p); // Print body p->Print(op->body); } @@ -1011,6 +1048,92 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); +SparseBlock::SparseBlock(Array sp_iter_vars, Array sp_structs, + Array> sp_struct_params, String name, Stmt body, + Optional init, Span span) { + CHECK_EQ(sp_structs.size(), sp_struct_params.size()) + << "ValueError: The length of `sp_struct_params` is expected to be equal to the length " + "`sp_structs`, which is the number of sparse data structures"; + Map> sp_struct_param_map; + for (int i = 0; i < static_cast(sp_structs.size()); ++i) { + ObjectRef obj = sp_structs[i]; + Array params = sp_struct_params[i]; + + if (obj->IsInstance()) { + CHECK(params.size() == 0) + << "ValueError: The number of function parameters for dense-fixed axes should be 0"; + } else if (obj->IsInstance()) { + CHECK(params.size() == 1) + << "ValueError: The number of function parameters for dense-variable axes should be 1"; + } else if (obj->IsInstance()) { + CHECK(params.size() == 1) + << "ValueError: The number of function parameters for sparse-fixed axes should be 1"; + } else if (obj->IsInstance()) { + CHECK(params.size() == 2) + << "ValueError: The number of function parameters for sparse-variable axes should be 2"; + } else if (obj->IsInstance()) { + CHECK(params.size() == 1) + << "ValueError: The number of function parameters for SparseBuffer should be 1"; + } else { + LOG(FATAL) << "ValueError: " << obj->_type_key << " is not a sparse data structure"; + } + + sp_struct_param_map.Set(obj, params); + } + + ObjectPtr node = make_object(); + node->sp_iter_vars = std::move(sp_iter_vars); + node->sp_structs = std::move(sp_structs); + node->sp_struct_param_map = std::move(sp_struct_param_map); + node->name = std::move(name); + node->body = std::move(body); + node->init = std::move(init); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.SparseBlock") + .set_body_typed([](Array sp_iter_vars, Array sp_structs, + Array> sp_struct_params, String name, Stmt body, + Optional init, Span span) { + return SparseBlock(sp_iter_vars, sp_structs, sp_struct_params, name, body, init, span); + }); + +TVM_REGISTER_NODE_TYPE(SparseBlockNode); + +void PrintSparseBlockTitle(const SparseBlockNode* op, ReprPrinter* p) { + p->stream << "sparse_block " << op->name << "("; + for (int i = 0; i < static_cast(op->sp_iter_vars.size()); ++i) { + p->Print(op->sp_iter_vars[i]); + if (i < static_cast(op->sp_iter_vars.size()) - 1) { + p->stream << ", "; + } + } + p->stream << ")"; +} + +void PrintSparseBlockBody(const SparseBlockNode* op, ReprPrinter* p) { + // Print init + PrintInitStmt(op->init, p); + // Print body + p->Print(op->body); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + PrintSparseBlockTitle(op, p); + p->stream << " {\n"; + p->indent += 2; + + PrintSparseBlockBody(op, p); + + p->indent -= 2; + p->PrintIndent(); + p->stream << "}\n"; + }); + PrimExpr TypeAnnotation(DataType dtype, Span span) { static auto op = Op::Get("tir.type_annotation"); return tir::Call(dtype, op, {}, span); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 949e8a1312aa..5e3985c8ce60 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -74,6 +74,11 @@ void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); } +void StmtVisitor::VisitStmt_(const SparseBufferStoreNode* op) { + this->VisitExpr(op->value); + VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); +} + void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) { VisitArray(op->bounds, [this](const Range& r) { this->VisitExpr(r->min); @@ -153,6 +158,13 @@ void StmtVisitor::VisitStmt_(const BlockRealizeNode* op) { this->VisitStmt(op->block); } +void StmtVisitor::VisitStmt_(const SparseBlockNode* op) { + if (op->init.defined()) { + this->VisitStmt(op->init.value()); + } + this->VisitStmt(op->body); +} + class StmtMutator::Internal { public: /*! @@ -386,6 +398,20 @@ Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { } } +Stmt StmtMutator::VisitStmt_(const SparseBufferStoreNode* op) { + PrimExpr value = this->VisitExpr(op->value); + Array indices = Internal::Mutate(this, op->indices); + + if (value.same_as(op->value) && indices.same_as(op->indices)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->value = std::move(value); + n->indices = std::move(indices); + return Stmt(n); + } +} + Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { Region bounds = Internal::Mutate(this, op->bounds); PrimExpr condition = this->VisitExpr(op->condition); @@ -563,6 +589,23 @@ Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) { } } +Stmt StmtMutator::VisitStmt_(const SparseBlockNode* op) { + Optional init = NullOpt; + if (op->init.defined()) { + init = VisitStmt(op->init.value()); + } + Stmt body = VisitStmt(op->body); + + if (init.same_as(op->init) && body.same_as(op->body)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->init = std::move(init); + n->body = std::move(body); + return Stmt(n); + } +} + // Implementations of IRTransform, PostOrderVisit and Substitute class IRApplyVisit : public StmtExprVisitor { public: diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 0e767ead4e6b..33b456b85e9d 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -222,6 +222,15 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce) TIR_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync) .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)); +TIR_DEFINE_BUILTIN_FUNC(tvm_lower_bound) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(tvm_upper_bound) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_BUILTIN_FUNC(tvm_atomic_add) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + TIR_DEFINE_BUILTIN_FUNC(tvm_mma_sync) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 696d82be721f..582474fa90cf 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -829,6 +829,21 @@ PrimExpr nearbyint(PrimExpr x, Span span) { TIR_REGISTER_PURE_UNARY_OP("tir.nearbyint"); +// lower_bound +PrimExpr lower_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) { + return tir::Call({kDLInt, 32, 1}, builtin::tvm_lower_bound(), {arr, val, l, r}, span); +} + +// upper_bound +PrimExpr upper_bound(tir::Var arr, PrimExpr val, PrimExpr l, PrimExpr r, Span span) { + return tir::Call({kDLInt, 32, 1}, builtin::tvm_upper_bound(), {arr, val, l, r}, span); +} + +// atomic_add +PrimExpr atomic_add(tir::Var ptr, PrimExpr val, Span span) { + return tir::Call(val->dtype, builtin::tvm_atomic_add(), {ptr, val}, span); +} + // trunc PrimExpr trunc(PrimExpr x, Span span) { if (x.dtype().is_int() || x.dtype().is_uint()) { @@ -943,6 +958,12 @@ TVM_REGISTER_GLOBAL("tir.trunc").set_body_typed(tvm::trunc); TVM_REGISTER_GLOBAL("tir._cast").set_body_typed(tvm::cast); +TVM_REGISTER_GLOBAL("tir.lower_bound").set_body_typed(tvm::lower_bound); + +TVM_REGISTER_GLOBAL("tir.upper_bound").set_body_typed(tvm::upper_bound); + +TVM_REGISTER_GLOBAL("tir.atomic_add").set_body_typed(tvm::atomic_add); + // operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_GLOBAL("tir." #Node).set_body_typed([](PrimExpr a, PrimExpr b, Span span) { \ diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 9c6d1e6e96da..3f1e4115b033 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -63,6 +63,17 @@ void VerifyCachedFlags(const ScheduleState& self); const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, GlobalVar* result_g_var); +/*! + * \brief Get PrimFunc and GlobalVar that the sparse block belongs to + * \param mod The IRModule + * \param sp_block The sparse block inside the PrimFunc to be queried + * \param result_g_var The result GlobalVar + * \return The result PrimFunc where the sparse block belongs to + * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write + */ +const PrimFuncNode* GetPrimFuncFromSparseBlock(const IRModule& mod, const SparseBlockNode* sp_block, + GlobalVar* result_g_var); + /*! * \brief Get the root node of the sref tree, which is the root block of the PrimFunc. * \param sref The given sref. diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index c7ed67187793..9445cc29e82e 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -45,6 +45,26 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl throw; } +const PrimFuncNode* GetPrimFuncFromSparseBlock(const IRModule& mod, const SparseBlockNode* sp_block, + GlobalVar* result_g_var) { + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* func = base_func.as()) { + if (func->body.get() == sp_block) { + if (result_g_var != nullptr) { + *result_g_var = g_var; + } + return func; + } + } + } + LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the " + "sparse block:\n" + << GetRef(sp_block); + throw; +} + /******** Scope ********/ StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 394f0f26db35..1cde535adadc 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -687,5 +687,55 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann /******** Schedule: Misc ********/ +/******** Schedule: SparseTIR schedules ********/ + +SparseBlockRV ConcreteScheduleNode::GetSparseBlock(const String& name, const String& func_name) { + class NotFoundResult : public ScheduleError { + public: + explicit NotFoundResult(String name, IRModule mod) : name_(name), mod_(mod) {} + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + String DetailRenderTemplate() const final { + return "Cannot find a sparse block with the name: " + name_; + } + + String FastErrorString() const final { + return "ScheduleError: Cannot find a sparse block with the specified name"; + } + + String name_; + IRModule mod_; + }; + + BaseFunc func = this->state_->mod->Lookup(func_name); + const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode); + + // Currently we only handle cases with single sparse block. + const auto* block = prim_func->body.as(); + if (block == nullptr) { + TVM_TIR_SCHEDULE_BEGIN(); + throw NotFoundResult(name, this->state_->mod); + TVM_TIR_SCHEDULE_END("get-sparse-block", this->error_render_level_); + } + + return CreateRV(GetRef(block)); +} + +Array ConcreteScheduleNode::GetSpIters(const SparseBlockRV& block_rv) { + return this->Get(block_rv)->sp_iter_vars; +} + +void ConcreteScheduleNode::SparseReorder(const SparseBlockRV& block_rv, + const Array& new_order) { + SparseBlock old_block = this->Get(block_rv); + SparseBlock new_block{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + new_block = tir::SparseReorder(state_, old_block, new_order); + TVM_TIR_SCHEDULE_END("sparse-reorder", this->error_render_level_); + this->UpdateRV(block_rv, new_block); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index f0f25ecafa3a..f14a28663bc6 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -70,6 +70,7 @@ class ConcreteScheduleNode : public ScheduleNode { inline Block Get(const BlockRV& block_rv) const final; inline For Get(const LoopRV& loop_rv) const final; inline PrimExpr Get(const ExprRV& expr_rv) const final; + inline SparseBlock Get(const SparseBlockRV& sp_block_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; inline bool HasBlock(const BlockRV& block_rv) const final; @@ -78,6 +79,7 @@ class ConcreteScheduleNode : public ScheduleNode { void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } + void RemoveRV(const SparseBlockRV& sp_block_rv) final { RemoveFromSymbolTable(sp_block_rv); } using ScheduleNode::GetSRef; public: @@ -134,6 +136,10 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Schedule: Misc ********/ void EnterPostproc() override {} + /******** Schedule: SparseTIR schedules ********/ + SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") override; + Array GetSpIters(const SparseBlockRV& block_rv) override; + void SparseReorder(const SparseBlockRV& block_rv, const Array& new_order) override; protected: /******** Utility functions ********/ @@ -171,6 +177,18 @@ class ConcreteScheduleNode : public ScheduleNode { * \return The new random variables created */ inline Array CreateRV(const std::vector& value); + /*! + * \brief Add a sparse block as a random variable into the symbol table + * \param sp_block + * \return SparseBlockRV + */ + inline SparseBlockRV CreateRV(const SparseBlock& sp_block); + /*! + * \brief Update the value of the input SparseBlockRV to the input block. + * \param sp_block_rv The random variable to be updated + * \param block The new value of the random variable + */ + inline void UpdateRV(const SparseBlockRV& sp_block_rv, const SparseBlock& block); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); /*! @@ -211,6 +229,13 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { return this->analyzer_->Simplify(transformed); } +inline SparseBlock ConcreteScheduleNode::Get(const SparseBlockRV& sp_block_rv) const { + auto it = this->symbol_table_.find(sp_block_rv); + CHECK(it != this->symbol_table_.end()) + << "IndexError: Cannot find corresponding SparseBlockRV: " << sp_block_rv; + return Downcast((*it).second); +} + inline bool ConcreteScheduleNode::HasBlock(const BlockRV& block_rv) const { auto it = this->symbol_table_.find(block_rv); if (it == this->symbol_table_.end()) { @@ -320,6 +345,16 @@ inline Array ConcreteScheduleNode::CreateRV(const std::vector& return results; } +inline SparseBlockRV ConcreteScheduleNode::CreateRV(const SparseBlock& block) { + SparseBlockRV rv; + this->symbol_table_.Set(rv, block); + return rv; +} + +inline void ConcreteScheduleNode::UpdateRV(const SparseBlockRV& rv, const SparseBlock& block) { + this->symbol_table_.Set(rv, block); +} + inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { auto it = this->symbol_table_.find(obj); if (it != this->symbol_table_.end()) { diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0cd2d3e6f38a..2b80c956241d 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -416,6 +416,20 @@ TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& an TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key); /******** Schedule: Misc ********/ +/******** Schedule: SparseTIR schedules ********/ + +/*! + * \brief Reorder a list of sparse iterators. It requires the new order to not break the iterator + * dependency. + * \param self The state of the schedule + * \param block The block to be transformed + * \param new_order The new order of the sparse iterators, whose length should equal to the number + * of the input block's sparse iterators + * \return The new sparse block, which is only used to update the corresponding random variable in + * concrete schedule. + */ +TVM_DLL SparseBlock SparseReorder(ScheduleState self, const SparseBlock& block, + const Array& new_order); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/sparse_loop_transformation.cc b/src/tir/schedule/primitive/sparse_loop_transformation.cc new file mode 100644 index 000000000000..aa4c58370eed --- /dev/null +++ b/src/tir/schedule/primitive/sparse_loop_transformation.cc @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check whether the new iterators are valid. We say they are valid if the new order is a + * permutation of the old order + * \param new_order The new iterator order to be checked + * \param old_order The old order of the iterators + * \throw ScheduleError If the iterators in the new order are not valid + */ +void CheckValidInputIterators(const ScheduleState self, const Array& new_order, + const Array& old_order) { + class LengthNotEqualError : public ScheduleError { + public: + explicit LengthNotEqualError(IRModule mod, Array old_order, + Array new_order) + : mod_(std::move(mod)), old_order_(std::move(old_order)), new_order_(std::move(new_order)) { + ICHECK_NE(new_order_.size(), old_order_.size()); + } + + String FastErrorString() const final { + return "ScheduleError: The number of iterators in the new order does not equal to the " + "number of iterators in the old order"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "ScheduleError: The new order has " << new_order_.size() << " iterators" << new_order_ + << ", while the old order has " << old_order_.size() << " iterators" << old_order_ + << ". They are supposed to have the same set of iterators"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + Array old_order_; + Array new_order_; + }; + + class IterNotAppearError : public ScheduleError { + public: + explicit IterNotAppearError(IRModule mod, SpIterVar iter, Array old_order) + : mod_(std::move(mod)), iter_(std::move(iter)), old_order_(std::move(old_order)) {} + + String FastErrorString() const final { + return "ScheduleError: An iterator in the new order does not appear in the old order"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "ScheduleError: Iterator " << iter_ + << " appears in the new order. However, it does not appear in the old order " + << old_order_; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + SpIterVar iter_; + Array old_order_; + }; + + if (new_order.size() != old_order.size()) { + throw LengthNotEqualError(self->mod, new_order, old_order); + } + for (const SpIterVar& sp_iter : new_order) { + if (std::find(old_order.begin(), old_order.end(), sp_iter) == old_order.end()) { + throw IterNotAppearError(self->mod, sp_iter, old_order); + } + } +} + +/*! + * \brief Check whether the sparse reorder would break dependency between iterators. + * \param new_order The new iterator order to be checked. + * \throw ScheduleError If the sparse reorder breaks dependency. + */ +void CheckDependency(const ScheduleState self, const Array& new_order) { + class DependencyError : public ScheduleError { + public: + explicit DependencyError(IRModule mod, SpIterVar iter, Array new_order): + mod_(std::move(mod)), iter_(std::move(iter)), new_order_(std::move(new_order)) {} + + String FastErrorString() const final { + return "ScheduleError: the sparse reorder breaks dependency between axes."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "ScheduleError: in new order " << new_order_ + << " iterator " << iter_ << " was placed before its dependent iterator."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + SpIterVar iter_; + Array new_order_; + }; + + std::set axes_set; + for (const SpIterVar& sp_iter : new_order) { + Axis axis = sp_iter->axis; + auto try_parent = axis->GetParentAxis(); + if (try_parent.defined()) { + Axis parent = try_parent.value(); + if (axes_set.find(parent) == axes_set.end()) { + throw DependencyError(self->mod, sp_iter, new_order); + } + } + axes_set.insert(axis); + } +} + + +SparseBlock SparseReorder(ScheduleState self, const SparseBlock& block, + const Array& new_order) { + // Step 1. Check whether the iterators in `new_order` are the same as `block`'s iterators. + CheckValidInputIterators(self, new_order, block->sp_iter_vars); + + // Step 2. Check whether the new order does not break the iterator dependency. + CheckDependency(self, new_order); + + // Step 3. Create the new SparseBlock. + ObjectPtr p_new_block = make_object(*block.get()); + p_new_block->sp_iter_vars = new_order; + SparseBlock new_block(p_new_block); + + // Step 4. Create the new IRModule. (The following lines are from Schedule::Replace(...)) + const PrimFuncNode* g_func = nullptr; + GlobalVar g_var; + g_func = GetPrimFuncFromSparseBlock(self->mod, block.get(), &g_var); + + IRModuleNode* new_mod = self->mod.CopyOnWrite(); + MapNode* new_map = new_mod->functions.CopyOnWrite(); + PrimFunc ref_new_func = Downcast(std::move(new_map->at(g_var))); + ICHECK(ref_new_func.get() == g_func); + PrimFuncNode* new_func = ref_new_func.CopyOnWrite(); + + new_func->body = new_block; + new_map->at(g_var) = std::move(ref_new_func); + self->mod = GetRef(new_mod); + + return new_block; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index b466843f9459..018471d5070c 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -27,6 +27,8 @@ BlockRV::BlockRV() { this->data_ = make_object(); } LoopRV::LoopRV() { this->data_ = make_object(); } +SparseBlockRV::SparseBlockRV() { this->data_ = make_object(); } + /**************** GetSRef ****************/ StmtSRef ScheduleNode::GetSRef(const StmtNode* stmt) const { @@ -42,6 +44,7 @@ StmtSRef ScheduleNode::GetSRef(const StmtNode* stmt) const { TVM_REGISTER_NODE_TYPE(BlockRVNode); TVM_REGISTER_NODE_TYPE(LoopRVNode); +TVM_REGISTER_NODE_TYPE(SparseBlockRVNode); TVM_REGISTER_OBJECT_TYPE(ScheduleNode); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetMod") // @@ -61,6 +64,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") // TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); +TVM_REGISTER_GLOBAL("tir.schedule.SparseBlockRV").set_body_typed([]() { return SparseBlockRV(); }); TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, int error_render_level) -> Schedule { @@ -87,6 +91,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGet") if (const auto* expr_rv = obj.as()) { return self->Get(GetRef(expr_rv)); } + if (const auto* sp_block_rv = obj.as()) { + return self->Get(GetRef(sp_block_rv)); + } LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << obj->GetTypeKey() << ". Its value is: " << obj; throw; @@ -116,6 +123,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") if (const auto* expr_rv = obj.as()) { return self->RemoveRV(GetRef(expr_rv)); } + if (const auto* sp_block_rv = obj.as()) { + return self->RemoveRV(GetRef(sp_block_rv)); + } LOG(FATAL) << "TypeError: Invalid type: " << obj->GetTypeKey(); throw; }); @@ -229,6 +239,13 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") /******** (FFI) Misc ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc") .set_body_method(&ScheduleNode::EnterPostproc); +/******** (FFI) SparseTIR schedules ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSparseBlock") + .set_body_method(&ScheduleNode::GetSparseBlock); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetSpIters") + .set_body_method(&ScheduleNode::GetSpIters); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSparseReorder") + .set_body_method(&ScheduleNode::SparseReorder); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 1e2e57eb6eca..96048ae53050 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -438,5 +438,29 @@ void TracedScheduleNode::EnterPostproc() { /*outputs=*/{})); } +/******** Schedule: SparseTIR schedules ********/ +SparseBlockRV TracedScheduleNode::GetSparseBlock(const String& name, const String& func_name) { + SparseBlockRV result = ConcreteScheduleNode::GetSparseBlock(name, func_name); + + // Do not support traced schedule so far. + + return result; +} + +Array TracedScheduleNode::GetSpIters(const SparseBlockRV& block_rv) { + Array result = ConcreteScheduleNode::GetSpIters(block_rv); + + // Do not support traced schedule so far. + + return result; +} + +void TracedScheduleNode::SparseReorder(const SparseBlockRV& block_rv, + const Array& new_order) { + ConcreteScheduleNode::SparseReorder(block_rv, new_order); + + // Do not support traced schedule so far. +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 5d3fdbf570de..6154bdedf4b0 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -97,6 +97,10 @@ class TracedScheduleNode : public ConcreteScheduleNode { void Unannotate(const BlockRV& block_rv, const String& ann_key) override; /******** Schedule: Misc ********/ void EnterPostproc() final; + /******** Schedule: SparseTIR schedules ********/ + SparseBlockRV GetSparseBlock(const String& name, const String& func_name = "main") final; + Array GetSpIters(const SparseBlockRV& block_rv) final; + void SparseReorder(const SparseBlockRV& block, const Array& new_order) final; }; } // namespace tir diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc new file mode 100644 index 000000000000..84525a97745a --- /dev/null +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -0,0 +1,676 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_sparse_tir.cc + */ + +#include +#include +#include +#include + +#include +#include + +#include "../../support/utils.h" +#include "../schedule/analysis.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Add the buffers accessed in sparse blocks to the PrimFunc's buffer map. + * \param f The PrimFunc whose buffer map is to be updated. + * \return The up to date buffer map. + */ +Map UpdateBufferMap(PrimFunc f) { + struct BufferMapUpdater : public StmtVisitor { + explicit BufferMapUpdater(Map buffer_map) : buffer_map_(std::move(buffer_map)) {} + + void VisitStmt_(const SparseBlockNode* sp_block) { + for (const auto& it : sp_block->sp_struct_param_map) { + const ObjectRef& sp_struct = it.first; + const Array& params = it.second; + if (const auto* dv_axis = sp_struct.as()) { + // collect indptr buffer of dense variable axis. + ICHECK_EQ(params.size(), 1); + buffer_map_.Set(params[0], dv_axis->indptr); + } else if (const auto* sf_axis = sp_struct.as()) { + // collect indices buffer of sparse fixed axis. + ICHECK_EQ(params.size(), 1); + buffer_map_.Set(params[0], sf_axis->indices); + } else if (const auto* sv_axis = sp_struct.as()) { + // collect indptr and indices buffer of sparse variable axis. + ICHECK_EQ(params.size(), 2); + buffer_map_.Set(params[0], sv_axis->indptr); + buffer_map_.Set(params[1], sv_axis->indices); + } else if (const auto* sp_buffer = sp_struct.as()) { + // collect data buffer for sparse buffers. + ICHECK_EQ(params.size(), 1); + buffer_map_.Set(params[0], sp_buffer->data); + } + } + } + + Map buffer_map_; + }; + + BufferMapUpdater updater(f->buffer_map); + updater(f->body); + return std::move(updater.buffer_map_); +} + +/*! + * \brief Aggregate the offset on previous axes with the index on the current axis. + * \param prev_offset The lowered offset accumulated over all prior axes. + * \param axis The current axis. + * \param index The sparse index on current axis. + * \param ana The analyzer used for expression simplification. + * \return The aggregated offset. + */ +PrimExpr AggregateOffset(SparseCtx* ctx, Axis axis, PrimExpr index, arith::Analyzer* ana) { + PrimExpr new_offset = axis->Aggregate(ctx, index); + if (ana != nullptr) { + return ana->Simplify(new_offset); + } else { + return new_offset; + } +} + +/*! \brief A class storing the context information of sparse blocks. */ +class SparseBlockCtx : public SparseCtx { + private: + struct Scope { + explicit Scope(SparseBlock sp_block) : sp_block(std::move(sp_block)) { + for (const SpIterVar& sp_iter_var : this->sp_block->sp_iter_vars) { + sp_iter_var_map.Set(sp_iter_var->var, sp_iter_var); + } + } + + /*! \brief The sparse block */ + SparseBlock sp_block; + /*! \brief A mapping from the internal variables of sparse iterators to the iterators */ + Map sp_iter_var_map; + /*! \brief The stored offsets of the axis in the sparse block */ + Map cached_offsets; + /*! \brief The stored coordinates of the axis in the sparse block */ + Map cached_coordinates; + }; + + public: + explicit SparseBlockCtx(arith::Analyzer* ana) : ana_(ana) {} + + void EnterScope(const SparseBlockNode* sp_block) { + stack_.emplace_back(GetRef(sp_block)); + /* Compute offsets and coordinates */ + size_t n_iters = sp_block->sp_iter_vars.size(); + for (size_t i = 0; i < n_iters;) { + SpIterVar sp_iter_var = sp_block->sp_iter_vars[i]; + Axis axis = sp_iter_var->axis; + + PrimExpr offset, index; + if (auto fused_axis = axis.as()) { + auto group = fused_axis->group; + offset = sp_block->sp_iter_vars[i + group.size() - 1]->var; + for (int j = group.size() - 1; j >= 0; --j) { + Axis orig = group[j]; + SetOffset(orig, offset); + if (j > 0) { + Buffer indptr; + if (auto sv_axis = orig.as()) { + indptr = sv_axis->indptr; + } else if (auto dv_axis = orig.as()) { + indptr = dv_axis->indptr; + } else { + throw; + } + offset = upper_bound(indptr->data, offset, Integer(0), indptr->shape[0]) - 1; + } + } + for (size_t j = 0; j < group.size(); ++j) { + Axis orig = group[j]; + offset = GetOffset(orig); + PrimExpr lb = std::get<0>(orig->GetOffsetExtent(this)); + index = offset - lb; + PrimExpr coordinate = orig->Decompress(this, offset, index); + SetCoordinate(orig, coordinate); + i++; + } + } else { + offset = AggregateOffset(this, axis, sp_iter_var->var, ana_); + index = sp_iter_var->var; + PrimExpr coordinate = axis->Decompress(this, offset, index); + SetOffset(axis, offset); + SetCoordinate(axis, coordinate); + i++; + } + } + } + + void ExitScope() { stack_.pop_back(); } + + /*! + * \brief Get the sparse iterator corresponding to the given variable in the current scope. + * \param var The variable whose corresponding sparse iterator is to be looked up. + * \return The corresponding sparse iterator of the input variable, or `NullOpt` if the input + * variable does not corresponds to a sparse iterator. + */ + Optional GetSparseIterVar(const VarNode* var) const { + return top()->sp_iter_var_map.Get(GetRef(var)); + } + + Optional GetPrevAxis(Axis axis) const { + // In Sparse block, previous axis is parent axis. + return axis->GetParentAxis(); + } + + void SetOffset(Axis axis, PrimExpr offset) { top()->cached_offsets.Set(axis, offset); } + + void SetCoordinate(Axis axis, PrimExpr idx) { top()->cached_coordinates.Set(axis, idx); } + + /*! + * \brief Get the offset of the input axis in the block. + * \param sp_iter The axis to be queried. + * \return The offset. + */ + PrimExpr GetOffset(Axis axis) const { + Optional try_offset = top()->cached_offsets.Get(axis); + CHECK(try_offset.defined()) << "The offset of axis " << axis->name << " not defined yet."; + PrimExpr offset = try_offset.value(); + return std::move(offset); + } + + /*! + * \brief Get the coordinate of the input axis in the block. + * \param axis The axis to be queried. + * \return The coordinate. + */ + PrimExpr GetCoordinate(Axis axis) const { + Optional try_index = top()->cached_coordinates.Get(axis); + CHECK(try_index.defined()) << "The index of axis not defined yet."; + PrimExpr index = try_index.value(); + return std::move(index); + } + + /*! + * \brief Get the iteration extent of the input sparse iterator. + * \param sp_iter_var The sparse iterator to be queried. + * \return The iteration extent of the input sparse iterator. + */ + PrimExpr GetIterExtent(SpIterVar sp_iter) { + if (const auto* fused_axis = sp_iter->axis.as()) { + // Fused axis. + if (fused_axis->index == int(fused_axis->group.size() - 1)) { + // The last axis in the fused group. + return fused_axis->GetNNZ(); + } else { + return Integer(1); + } + } + PrimExpr lb, ub; + std::tie(lb, ub) = sp_iter->axis->GetOffsetExtent(this); + return ana_->Simplify(ub - lb); + } + + Optional MatchAxis(SparseCtx* buf_ctx, Axis axis) { + if (!top()->cached_offsets.Get(axis).defined()) { + return NullOpt; + } else { + Axis axis_ = axis; + auto prev = buf_ctx->GetPrevAxis(axis); + auto blk_prev = GetPrevAxis(axis); + for (; prev.defined();) { + if (prev != blk_prev) { + return NullOpt; + } else { + axis_ = prev.value(); + prev = buf_ctx->GetPrevAxis(axis_); + blk_prev = GetPrevAxis(axis_); + } + } + return axis; + } + } + + bool MatchIndex(Optional matched_axis, PrimExpr expr) { + if (!matched_axis.defined()) { + return false; + } + auto var = expr.as(); + if (var == nullptr) { + return false; + } + auto try_sp_iter_var = top()->sp_iter_var_map.Get(GetRef(var)); + if (!try_sp_iter_var.defined()) { + return false; + } + Axis axis = try_sp_iter_var.value()->axis; + if (auto fused_axis = axis.as()) { + axis = fused_axis->group[fused_axis->index]; + } + return axis == matched_axis.value(); + } + + private: + std::vector stack_; + arith::Analyzer* ana_; + + /*! \brief The top scope in the sparse block stack. */ + inline Scope* top() const { return const_cast(&stack_.back()); } +}; + +/*! \brief A class storing the context information of sparse buffer accesses. */ +class SparseBufferAccessCtx : public SparseCtx { + private: + struct Scope { + explicit Scope(Array axes) : axes(std::move(axes)) {} + + Array axes; + /*! \brief The stored offsets of the axis in the sparse buffer */ + Map cached_offsets; + /*! \brief The stored coordinates of the axis in the sparse buffer */ + Map cached_coordinates; + PrimExpr final_offset; + }; + + public: + explicit SparseBufferAccessCtx(arith::Analyzer* ana) : ana_(ana) {} + + void EnterScope(SparseBuffer sp_buffer, Array raw_indices_, Array coordinates, + SparseBlockCtx* sp_blk_ctx) { + stack_.emplace_back(sp_buffer->axes); + size_t n_dims = sp_buffer->axes.size(); + ICHECK(n_dims == raw_indices_.size()) + << "The number of indices does not equal number of axes in the sparse buffer."; + ICHECK(n_dims == coordinates.size()) + << "The number of coordinates does not equal number of axes in the sparse buffer."; + + /* Compute offsets and coordinates. */ + for (size_t i = 0; i < n_dims; ++i) { + Axis axis = sp_buffer->axes[i]; + PrimExpr coordinate = coordinates[i]; + SetCoordinate(axis, coordinate); + auto try_parent = axis->GetParentAxis(); + // update axis match + + auto matched_axis = sp_blk_ctx->MatchAxis(this, axis); + // compute offset + PrimExpr offset = (sp_blk_ctx->MatchIndex(matched_axis, raw_indices_[i])) + ? sp_blk_ctx->GetOffset(axis) + : AggregateOffset(this, axis, axis->Compress(this, coordinate), ana_); + SetOffset(axis, offset); + if (i + 1 == n_dims) { + // the final axis; + top()->final_offset = offset; + } + } + } + + void ExitScope() { stack_.pop_back(); } + + Optional GetPrevAxis(Axis axis) const { + Array axes = top()->axes; + Optional ret = NullOpt; + for (auto it : axes) { + if (it == axis) { + break; + } + ret = it; + } + return ret; + } + + void SetOffset(Axis axis, PrimExpr offset) { top()->cached_offsets.Set(axis, offset); } + + void SetCoordinate(Axis axis, PrimExpr coordinate) { + top()->cached_coordinates.Set(axis, coordinate); + } + + PrimExpr GetOffset(Axis axis) const { + auto try_offset = top()->cached_offsets.Get(axis); + CHECK(try_offset.defined()) << "The offset of the axis is not defined."; + PrimExpr offset = try_offset.value(); + return std::move(offset); + } + + PrimExpr GetCoordinate(Axis axis) const { + auto try_coordinate = top()->cached_coordinates.Get(axis); + CHECK(try_coordinate.defined()) << "The coordinate of the axis is not defined."; + PrimExpr coordinate = try_coordinate.value(); + return std::move(coordinate); + } + + PrimExpr GetLastOffset() const { return top()->final_offset; } + + private: + std::vector stack_; + arith::Analyzer* ana_; + + /*! \brief The top scope in the sparse buffer access stack. */ + inline Scope* top() const { return const_cast(&stack_.back()); } +}; + +/*! + * \brief Rewrite the high-dimensional sparse buffers and access indices to low-level buffers and + * offsets. + */ +class IndexTransformer : public StmtExprMutator { + public: + explicit IndexTransformer() : sp_blk_ctx_(&ana_), sp_buf_ctx_(&ana_) {} + + private: + /*! + * \brief Convert the input sparse iterator to a block iterator. + * \param sp_iter The sparse iterator to be converted. + * \param var_map The mapping from sparse iterators to loop variables, for extent substitution. + * \return The corresponding block iterator. + */ + IterVar SpIterVarToIterVar(const SpIterVar& sp_iter, Map var_map) { + // Substitute the iteration vars in the expression with the loop vars. + return IterVar(Range::FromMinExtent(0, sp_blk_ctx_.GetIterExtent(sp_iter)), + sp_iter->var, sp_iter->is_reduction ? kCommReduce : kDataPar); + } + + /*! + * \brief Generate the read and write regions for sparse blocks. + * \param sp_block The sparse block, which is the source of the reads and writes. + * \param reads The read regions of the sparse block. + * \param writes The write regions of the sparse block. + */ + void GenerateReadWriteRegions(const SparseBlockNode* sp_block, Array* reads, + Array* writes) { + for (const ObjectRef& obj : sp_block->sp_structs) { + if (const auto* dv_axis = obj.as()) { + reads->push_back(BufferRegion::FullRegion(dv_axis->indptr)); + } else if (const auto* sf_axis = obj.as()) { + reads->push_back(BufferRegion::FullRegion(sf_axis->indices)); + } else if (const auto* sv_axis = obj.as()) { + reads->push_back(BufferRegion::FullRegion(sv_axis->indptr)); + reads->push_back(BufferRegion::FullRegion(sv_axis->indices)); + } else if (const auto* sp_buffer = obj.as()) { + if (buffer_read_.count(sp_buffer)) { + reads->push_back(BufferRegion::FullRegion(sp_buffer->data)); + } + if (buffer_write_.count(sp_buffer)) { + writes->push_back(BufferRegion::FullRegion(sp_buffer->data)); + } + } + } + } + + /*! + * \brief Generated the loop nests for the outside the input body. + * \param body The statement to be wrapped by loop nests. + * \param block_iters The block iterators defined in the outermost block in `body`. + * \param loop_vars The loop variables of the loops to be generated. + * \return The outermost generated loop. + */ + Stmt GenerateLoops(Stmt body, const Array& block_iters, const Array& loop_vars) { + int n_iter = static_cast(block_iters.size()); + for (int i = n_iter - 1; i >= 0; --i) { + const Range& dom = block_iters[i]->dom; + body = For(loop_vars[i], dom->min, dom->extent, ForKind::kSerial, std::move(body)); + } + return body; + } + + PrimExpr VisitExpr_(const VarNode* var) final { + auto try_sp_iter = sp_blk_ctx_.GetSparseIterVar(var); + if (try_sp_iter.defined()) { + SpIterVar sp_iter = try_sp_iter.value(); + Axis axis = sp_iter->axis; + if (auto fused_axis = axis.as()) { + axis = fused_axis->group[fused_axis->index]; + } + return sp_blk_ctx_.GetCoordinate(axis); + } else { + return GetRef(var); + } + } + + PrimExpr VisitExpr_(const SparseBufferLoadNode* load) final { + buffer_read_.insert(load->buffer.get()); + Array coordinates; + for (const PrimExpr& index : load->indices) { + coordinates.push_back(VisitExpr(index)); + } + sp_buf_ctx_.EnterScope(load->buffer, load->indices, coordinates, &sp_blk_ctx_); + PrimExpr offset = sp_buf_ctx_.GetLastOffset(); + sp_buf_ctx_.ExitScope(); + return BufferLoad(load->buffer->data, {std::move(offset)}); + } + + Stmt VisitStmt_(const SparseBufferStoreNode* store) final { + buffer_write_.insert(store->buffer.get()); + Array coordinates; + for (const PrimExpr& index : store->indices) { + coordinates.push_back(VisitExpr(index)); + } + sp_buf_ctx_.EnterScope(store->buffer, store->indices, coordinates, &sp_blk_ctx_); + PrimExpr offset = sp_buf_ctx_.GetLastOffset(); + sp_buf_ctx_.ExitScope(); + PrimExpr value = VisitExpr(store->value); + return BufferStore(store->buffer->data, std::move(value), {std::move(offset)}); + } + + Stmt VisitStmt_(const SparseBlockNode* sp_block) final { + /*! \brief A class temporarily storing the block signatures and the outer loop variables of the + * blocks to be generated */ + struct BlockInfo { + /*! \brief The outer loop variables of the block */ + Array loop_vars; + /*! \brief The block iterators of the block */ + Array block_iters; + /*! \brief The block iterator bindings of the block */ + Array iter_bindings; + /*! \brief The init statement of the block */ + Optional init; + + /*! + * \brief Push a new loop variable/block iterator/iterator binding to this block. + * \param loop_var The loop variable to be pushed. + * \param block_iter The block iterator to be pushed. + * \param iter_binding The iterator binding to be pushed. + */ + void Push(const Var& loop_var, const IterVar& block_iter, const PrimExpr& iter_binding) { + loop_vars.push_back(loop_var); + block_iters.push_back(block_iter); + iter_bindings.push_back(iter_binding); + } + + /*! + * \brief Check whether the input loop variable exists in the outer loop variables of this + * block. + * \param target_loop_var The loop variable to be checked + * \return Whether the input loop variable exists in the outer loop variables of this block. + */ + bool LoopVarAppears(const Var& target_loop_var) { + for (const Var& loop_var : loop_vars) { + if (loop_var.same_as(target_loop_var)) { + return true; + } + } + return false; + } + + /*! + * \brief Check whether a new block is needed. We need to create a new block when: + * - the input axis is variable (dense-variable or sparse-variable), and + * - the parent axis of the input axis has corresponding loop variable in the current block. + * \param axis The axis to be checked. + * \param axis2loop_var The mapping from axes to their corresponding loop variables. + * \param defined_loop_vars The loop variables defined in previous blocks + * (excluding the current one). + * \return Whether a new block is needed according to the conditions above. + */ + bool NeedCreateNewBlock(Axis axis, Map axis2loop_var, + const std::unordered_set& defined_loop_vars) { + if (axis->kind() != AxisKind::kDenseVariable && axis->kind() != AxisKind::kSparseVariable) { + return false; + } + + const Optional& loop_var = axis2loop_var.Get(axis->GetParentAxis().value()); + CHECK(loop_var.defined()) << "ValueError: The parent axis of " << axis + << " does not appear in the sparse block"; + + if (LoopVarAppears(loop_var.value())) { + return true; + } + CHECK(defined_loop_vars.count(loop_var.value().get())) + << "ValueError: The parent axis of " << axis + << " should appear before it in the sparse block"; + return false; + } + }; + + int n_iter = static_cast(sp_block->sp_iter_vars.size()); + buffer_read_.clear(); + buffer_write_.clear(); + + // Step 1. Enter a new sparse block scope. + sp_blk_ctx_.EnterScope(sp_block); + + // Step 2. Recursively mutate the `init` field and the block body. + Optional init = + sp_block->init.defined() ? VisitStmt(sp_block->init.value()) : Optional(NullOpt); + Stmt body = VisitStmt(sp_block->body); + + // Step 3. Create the new loop variables. + Map var_map; + Map axis2loop_var; + for (const SpIterVar& sp_iter_var : sp_block->sp_iter_vars) { + Var loop_var("v_" + sp_iter_var->var->name_hint); + var_map.Set(sp_iter_var->var, loop_var); + if (auto fused_axis = sp_iter_var->axis.as()) { + // handle the special case of fused_axis + axis2loop_var.Set(fused_axis->group[fused_axis->index], loop_var); + } + axis2loop_var.Set(sp_iter_var->axis, loop_var); + } + + // Step 4. Gather the information of the blocks to be generated. + std::unordered_set defined_loop_vars; + std::vector block_infos(1); + /* Whether a reduction block iterator has appeared */ + bool has_reduction_var = false; + + for (int i = 0; i < n_iter; ++i) { + SpIterVar sp_it_var = sp_block->sp_iter_vars[i]; + if (block_infos.back().NeedCreateNewBlock(sp_it_var->axis, axis2loop_var, + defined_loop_vars)) { + // Mark the loop variables corresponding to the current block as "defined". + for (const Var& loop_var : block_infos.back().loop_vars) { + defined_loop_vars.insert(loop_var.get()); + } + // Create a new BlockInfo. + block_infos.emplace_back(); + } + + Var loop_var = Downcast(var_map.Get(sp_it_var->var)); + block_infos.back().Push(loop_var, SpIterVarToIterVar(sp_it_var, var_map), loop_var); + if (!has_reduction_var && sp_it_var->is_reduction) { + block_infos.back().init = std::move(init); + has_reduction_var = true; + } + } + + // Step 5. Generate the read-regions and write-retions of the block. + Array reads; + Array writes; + GenerateReadWriteRegions(sp_block, &reads, &writes); + + // Step 6. Generate nested blocks and loops from innermost to outermost. + for (int i = static_cast(block_infos.size()) - 1; i >= 0; --i) { + BlockInfo info = std::move(block_infos[i]); + Block block(/*iter_vars=*/info.block_iters, + /*reads=*/reads, + /*writes=*/writes, + /*name_hint=*/sp_block->name + std::to_string(i), + /*body=*/std::move(body), + /*init=*/std::move(info.init), + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/{{"sparse", Bool(true)}}); + BlockRealize block_realize(/*iter_values=*/std::move(info.iter_bindings), + /*predicate=*/const_true(), + /*block=*/std::move(block)); + Stmt loop = GenerateLoops(std::move(block_realize), std::move(info.block_iters), + std::move(info.loop_vars)); + body = std::move(loop); + } + + // Step 7: Exit the sparse block scope. + sp_blk_ctx_.ExitScope(); + + return body; + } + + SparseBlockCtx sp_blk_ctx_; + SparseBufferAccessCtx sp_buf_ctx_; + std::unordered_set buffer_read_; + std::unordered_set buffer_write_; + arith::Analyzer ana_; +}; + +/*! + * \brief Wrap the body statement with an empty root block. + * \param body The body statements to wrap with. + * \return The wrapped block. + */ +Stmt WrapWithRootBlock(Stmt body) { + Block root_block({}, {}, {}, "root", std::move(body)); + return BlockRealize({}, const_true(), std::move(root_block)); +} + +PrimFunc LowerSparseTIR(PrimFunc f) { + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + // Step 1. Update the PrimFunc's buffer map. + fptr->buffer_map = UpdateBufferMap(f); + // Step 2. Lower indices. + fptr->body = IndexTransformer()(std::move(fptr->body)); + // Step 3. Wrap the function body with a root block. + fptr->body = WrapWithRootBlock(std::move(fptr->body)); + return f; + } else { + return f; + } +} + +namespace transform { + +/*! + * \brief The lowering pass from TIR to Sparse TIR. + */ +Pass LowerSparseTIR() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return LowerSparseTIR(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LowerSparseTIR", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LowerSparseTIR").set_body_typed(LowerSparseTIR); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/sparsetir/bench_rgcn.py b/tests/python/sparsetir/bench_rgcn.py new file mode 100644 index 000000000000..f3d5c5b25c53 --- /dev/null +++ b/tests/python/sparsetir/bench_rgcn.py @@ -0,0 +1,175 @@ +from dgl.heterograph import DGLHeteroGraph +import tvm +import tvm.testing +import tvm.tir as tir +import scipy.sparse as sp +import numpy as np +import dgl +import dgl.function as fn +import torch as th +from tvm.script import tir as T +from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset +from lowered_tir import lowered_rgcn_forward +from sparse_tir_scripts import rgcn_forward + + +def get_dataset_by_name(name: str): + if name == "aifb": + return AIFBDataset() + elif name == "mutag": + return MUTAGDataset() + elif name == "bgs": + return BGSDataset() + elif name == "am": + return AMDataset() + else: + raise KeyError("Unknown dataset {}.".format(name)) + + +class TorchOpTimer(object): + def __enter__(self): + self.start_event = th.cuda.Event(enable_timing=True) + self.end_event = th.cuda.Event(enable_timing=True) + self.start_event.record() + return self + + def __exit__(self, type, value, traceback): + self.end_event.record() + th.cuda.synchronize() # Wait for the events to be recorded! + self.time = self.start_event.elapsed_time(self.end_event) + + +def prepare_hetero_graph_simplified(g: dgl.DGLHeteroGraph): + ntype_pointer = np.cumsum([0] + [g.number_of_nodes(ntype) for ntype in g.ntypes]) + + etype_pointer = [0] + for etype in g.canonical_etypes: + g_sub = g[etype] + etype_pointer.append(etype_pointer[-1] + g_sub.num_edges()) + + return { + "ntype_node_pointer": th.IntTensor(ntype_pointer).cuda(), + "etype_edge_pointer": th.IntTensor(etype_pointer).cuda(), + } + + +def test_rgcn(g: DGLHeteroGraph, feat_size: int): + g = g.to(0) + feat = th.rand(g.num_src_nodes(), feat_size).to(0) / 100 + out = th.zeros(g.num_dst_nodes(), feat_size).to(0) / 100 + weight = th.rand(g.num_rels, feat_size, feat_size).to(0) + indptr, indices, eid = g.adj_sparse(fmt="csc") + etype = g.edata[dgl.ETYPE][eid] + + cold_start = 3 + total = 10 + accum = 0 + + # dgl-lowmem + try: + g.srcdata["feat"] = feat.unsqueeze(-1) + us, vs = g.edges() + feat_transformed = feat[us] + msg = th.zeros(g.num_edges(), feat_size).to(0) + for epoch in range(10): + with TorchOpTimer() as timer: + with th.no_grad(): + for i in range(1, len(g.etype_pointer)): + start = g.etype_pointer[i - 1] + end = g.etype_pointer[i] + msg[start:end] = feat_transformed[start:end] @ weight[i - 1] + y_dgl_lowmem = dgl.ops.copy_e_sum(g, msg) + if epoch >= cold_start: + accum += timer.time + print("dgl-lowmem:\t\t {}ms".format(accum / (total - cold_start))) + y_dgl_lowmem = None + except RuntimeError as err: + print("dgl-lowmem: OOM") + y_dgl_lowmem = None + except BaseException as err: + print(err) + raise + + # dgl-bmm + + def msg_func(edges): + h = edges.src["feat"] + W = weight[edges.data[dgl.ETYPE]] + return {"msg": W @ h} + + try: + g.srcdata["feat"] = feat.unsqueeze(-1) + for epoch in range(10): + with TorchOpTimer() as timer: + with th.no_grad(): + g.update_all(msg_func, fn.sum("msg", "y")) + y_dgl = g.dstdata["y"].squeeze(-1) + if epoch >= cold_start: + accum += timer.time + print("dgl-bmm:\t\t {}ms".format(accum / (total - cold_start))) + except RuntimeError as err: + print("dgl-bmm: OOM") + y_dgl = None + except BaseException as err: + print(err) + raise + + # tir + mod = tvm.IRModule.from_expr(rgcn_forward) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_rgcn_forward, True) + + N, R, FEAT_SIZE, NNZ = mod["main"].params[-4:] + sch = tir.Schedule( + mod["main"].specialize( + {N: g.number_of_nodes(), R: g.num_rels, FEAT_SIZE: feat_size, NNZ: g.number_of_edges()} + ) + ) + + outer = sch.get_block("rgcn-forward0") + inner = sch.get_block("rgcn-forward1") + i, f_out = sch.get_loops(outer) + j, f_in = sch.get_loops(inner) + sch.bind(i, "blockIdx.x") + sch.bind(f_out, "threadIdx.y") + sch.bind(f_in, "threadIdx.x") + f = tvm.build(sch.mod, target="cuda") + + E = tvm.nd.array(etype.cpu().numpy().astype("int32"), device=tvm.cuda(0)) + W = tvm.nd.array(weight.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0)) + X = tvm.nd.array(feat.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0)) + Y = tvm.nd.array(out.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0)) + indptr = tvm.nd.array(indptr.cpu().numpy().astype("int32"), device=tvm.cuda(0)) + indices = tvm.nd.array(indices.cpu().numpy().astype("int32"), device=tvm.cuda(0)) + + cold_start = 3 + total = 10 + accum = 0 + + for epoch in range(10): + with TorchOpTimer() as timer: + f(E, W, X, Y, indptr, indices) + if epoch >= cold_start: + accum += timer.time + + print("sparse-tir:\t\t {}ms".format(accum / (total - cold_start))) + + if y_dgl is not None: + tvm.testing.assert_allclose(y_dgl.view(-1).cpu().numpy(), Y.numpy(), rtol=1e-4) + if y_dgl_lowmem is not None: + tvm.testing.assert_allclose(y_dgl_lowmem.view(-1).cpu().numpy(), Y.numpy(), rtol=1e-4) + + +if __name__ == "__main__": + for feat_size in [4, 8, 16, 32, 64]: + for name in ["aifb", "mutag", "bgs", "am"]: + print("dataset {}, feat_size={}:".format(name, feat_size)) + dataset = get_dataset_by_name(name) + g = dataset[0] + type_pointers = prepare_hetero_graph_simplified(g) + g = dgl.to_homogeneous(g) + g.ntype_pointer = type_pointers["ntype_node_pointer"] + g.etype_pointer = type_pointers["etype_edge_pointer"] + g.num_ntypes = max(g.ndata[dgl.NTYPE]).item() + 1 + g.num_rels = max(g.edata[dgl.ETYPE]).item() + 1 + test_rgcn(g, feat_size) diff --git a/tests/python/sparsetir/bench_rgcn_new.py b/tests/python/sparsetir/bench_rgcn_new.py new file mode 100644 index 000000000000..a0ed3248fa5e --- /dev/null +++ b/tests/python/sparsetir/bench_rgcn_new.py @@ -0,0 +1,210 @@ +from dgl.heterograph import DGLHeteroGraph +import tvm +import tvm.testing +import tvm.tir as tir +import scipy.sparse as sp +import numpy as np +import dgl +import dgl.function as fn +import torch as th +from tvm.script import tir as T +from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset + + +def get_dataset_by_name(name: str): + if name == 'aifb': + return AIFBDataset() + elif name == 'mutag': + return MUTAGDataset() + elif name == 'bgs': + return BGSDataset() + elif name == 'am': + return AMDataset() + else: + raise KeyError("Unknown dataset {}.".format(name)) + + +class TorchOpTimer(object): + def __enter__(self): + self.start_event = th.cuda.Event(enable_timing=True) + self.end_event = th.cuda.Event(enable_timing=True) + self.start_event.record() + return self + + def __exit__(self, type, value, traceback): + self.end_event.record() + th.cuda.synchronize() # Wait for the events to be recorded! + self.time = self.start_event.elapsed_time(self.end_event) + + +def prepare_hetero_graph_simplified(g: dgl.DGLHeteroGraph): + ntype_pointer = np.cumsum( + [0] + [g.number_of_nodes(ntype) for ntype in g.ntypes]) + + etype_pointer = [0] + for etype in g.canonical_etypes: + g_sub = g[etype] + etype_pointer.append(etype_pointer[-1] + g_sub.num_edges()) + + return{"ntype_node_pointer": th.IntTensor(ntype_pointer), "etype_edge_pointer": th.IntTensor(etype_pointer)} + + +@T.prim_func +def rgcn_hetero_forward( + w: T.handle, + x: T.handle, + y: T.handle, + indptr_i: T.handle, + indices_i: T.handle, + indptr_j: T.handle, + indices_j: T.handle, + n: T.int32, + r: T.int32, + feat_size: T.int32, + nnz_i: T.int32, + nnz_j: T.int32 +): + N = T.dense_fixed(n) + R = T.dense_fixed(r) + I = T.sparse_variable(R, (n, nnz_i), (indptr_i, indices_i), "int32") + J = T.sparse_variable(I, (n, nnz_j), (indptr_j, indices_j), "int32") + F_in = T.dense_fixed(feat_size) + F_out = T.dense_fixed(feat_size) + W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32") + X = T.match_sparse_buffer(x, (N, F_in), "float32") + Y = T.match_sparse_buffer(y, (N, R, F_out), "float32") + with T.iter([R, I, F_out, J, F_in], "SSSRR", "rgcn-hetero-forward") as [ + vr, vi, vout, vj, vin + ]: + with T.init(): + Y[vi, vr, vout] = 0. + Y[vi, vr, vout] = Y[vi, vr, vout] + W[vr, vout, vin] * X[vj, vin] + + +@T.prim_func +def func(w: T.handle, x: T.handle, y: T.handle, indptr_i: T.handle, indices_i: T.handle, indptr_j: T.handle, indices_j: T.handle, n: T.int32, r: T.int32, feat_size: T.int32, nnz_i: T.int32, nnz_j: T.int32) -> None: + W_data = T.match_buffer(w, [r * feat_size * feat_size], dtype="float32") + X_data = T.match_buffer(x, [n * feat_size], dtype="float32") + Y_data = T.match_buffer(y, [n * r * feat_size], dtype="float32") + I_indptr = T.match_buffer(indptr_i, [r + 1], dtype="int32") + I_indices = T.match_buffer(indices_i, [nnz_i], dtype="int32") + J_indptr = T.match_buffer(indptr_j, [nnz_i + 1], dtype="int32") + J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") + # body + # with T.block("root") + for v_vr in T.serial(r): + with T.block("rgcn-hetero-forward0"): + vr = T.axis.spatial(r, v_vr) + T.reads(I_indptr[0: r + 1], I_indices[0: nnz_i], J_indptr[0: nnz_i + 1], J_indices[0: nnz_j], + W_data[0: r * feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * r * feat_size]) + T.writes(Y_data[0: n * r * feat_size]) + T.block_attr({"sparse": True}) + W_data_shared = T.alloc_buffer([feat_size * feat_size], dtype="float32", scope="shared") + for ax0 in T.serial(feat_size * feat_size): + with T.block("W_data_shared"): + v0 = T.axis.spatial(feat_size * feat_size, ax0) + T.reads(W_data[feat_size * feat_size * vr + v0]) + T.writes(W_data_shared[v0]) + W_data_shared[v0] = W_data[vr * feat_size * feat_size + v0] + for v_vi, v_vout in T.grid(I_indptr[vr + 1] - I_indptr[vr], feat_size): + with T.block("rgcn-hetero-forward1"): + vi, vout = T.axis.remap("SS", [v_vi, v_vout]) + T.reads(I_indptr[0: r + 1], I_indices[0: nnz_i], J_indptr[0: nnz_i + 1], J_indices[0: nnz_j], + W_data_shared[0: feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * r * feat_size]) + T.writes(Y_data[0: n * r * feat_size]) + T.block_attr({"sparse": True}) + for v_vj, v_vin in T.grid(J_indptr[I_indptr[vr] + vi + 1] - J_indptr[I_indptr[vr] + vi], feat_size): + with T.block("rgcn-hetero-forward2"): + vj, vin = T.axis.remap("RR", [v_vj, v_vin]) + T.reads(I_indptr[0: r + 1], I_indices[0: nnz_i], J_indptr[0: nnz_i + 1], J_indices[0: nnz_j], + W_data_shared[0: feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * r * feat_size]) + T.writes(Y_data[0: n * r * feat_size]) + T.block_attr({"sparse": True}) + with T.init(): + Y_data[((I_indices[I_indptr[vr] + vi]) + * r + vr) * feat_size + vout] = T.float32(0) + Y_data[((I_indices[I_indptr[vr] + vi]) * r + vr) * feat_size + vout] = Y_data[((I_indices[I_indptr[vr] + vi]) * r + vr) + * feat_size + vout] + W_data_shared[vout * feat_size + vin] * X_data[J_indices[J_indptr[I_indptr[vr] + vi] + vj] * feat_size + vin] + + +def test_lower_rgcn_hetero(g: dgl.DGLHeteroGraph, feat_size: int): + mod = tvm.IRModule.from_expr(func) + N, R, FEAT_SIZE, NNZ_I, NNZ_J = mod["main"].params[-5:] + n = g.num_nodes() + r = len(g.etypes) + nnz_j = g.num_edges() + + feat = th.rand(n, feat_size).to(0) / 100 + out = th.zeros(n, r, feat_size).to(0) / 100 + weight = th.rand(r, feat_size, feat_size).to(0) + W = tvm.nd.array(weight.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0)) + X = tvm.nd.array(feat.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0)) + Y = tvm.nd.array(out.view(-1).cpu().numpy().astype("float32"), device=tvm.cuda(0)) + + indptr_i = [th.LongTensor([0])] + indices_i = [] + indptr_j = [th.LongTensor([0])] + indices_j = [] + for etype in g.canonical_etypes: + src_type, _, dst_type = etype + etype_id = g.get_etype_id(etype) + src_type_id = g.get_ntype_id(src_type) + dst_type_id = g.get_ntype_id(dst_type) + g_sub = g[etype] + indptr, indices, _ = g_sub.adj_sparse(fmt="csc") + + unique_nodes = th.nonzero(indptr[:-1] != indptr[1:]).squeeze(1) + indptr_i.append(th.LongTensor([len(unique_nodes)])) + indices_i.append(unique_nodes + g.ntype_pointer[dst_type_id]) + indptr_j.append(indptr[unique_nodes] + g.etype_pointer[etype_id]) + indices_j.append(indices + g.ntype_pointer[src_type_id]) + + indptr_i = tvm.nd.array(th.cat(indptr_i).numpy().astype("int32"), device=tvm.cuda(0)) + indices_i = tvm.nd.array(th.cat(indices_i).numpy().astype("int32"), device=tvm.cuda(0)) + indptr_j = tvm.nd.array(th.cat(indptr_j).numpy().astype("int32"), device=tvm.cuda(0)) + indices_j = tvm.nd.array(th.cat(indices_j).numpy().astype("int32"), device=tvm.cuda(0)) + + nnz_i = indices_i.shape[0] + + sch = tir.Schedule( + mod["main"].specialize( + {N: n, R: r, FEAT_SIZE: feat_size, NNZ_I: nnz_i, NNZ_J: nnz_j} + ) + ) + + blk0 = sch.get_block("rgcn-hetero-forward0") + blk1 = sch.get_block("rgcn-hetero-forward1") + blk2 = sch.get_block("rgcn-hetero-forward2") + r, = sch.get_loops(blk0) + i, f_out = sch.get_loops(blk1) + j, f_in = sch.get_loops(blk2) + i1, i2 = sch.split(i, [None, 8]) + sch.bind(i2, "blockIdx.x") + sch.bind(r, "blockIdx.y") + sch.bind(f_out, "threadIdx.y") + sch.bind(f_in, "threadIdx.x") + f = tvm.build(sch.mod["main"], target="cuda") + + cold_start = 3 + total = 10 + accum = 0 + + for epoch in range(10): + with TorchOpTimer() as timer: + f(W, X, Y, indptr_i, indices_i, indptr_j, indices_j) + if epoch >= cold_start: + accum += timer.time + + print("sparse-tir:\t\t {}ms".format(accum / (total - cold_start))) + + +if __name__ == "__main__": + for feat_size in [32]: # [4, 8, 16, 32, 64]: + for name in ['bgs']: # ['aifb', 'mutag', 'bgs', 'am']: + print('dataset {}:'.format(name)) + dataset = get_dataset_by_name(name) + g = dataset[0] + type_pointers = prepare_hetero_graph_simplified(g) + g.ntype_pointer = type_pointers['ntype_node_pointer'] + g.etype_pointer = type_pointers['etype_edge_pointer'] + test_lower_rgcn_hetero(g, feat_size) diff --git a/tests/python/sparsetir/lowered_tir.py b/tests/python/sparsetir/lowered_tir.py new file mode 100644 index 000000000000..31c076e4bb29 --- /dev/null +++ b/tests/python/sparsetir/lowered_tir.py @@ -0,0 +1,464 @@ +"""Lowered TIR scripts of sparse workloads.""" +from tvm.script import tir as T + + +@T.prim_func +def lowered_csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, (nnz,), "float32") + B_data = T.match_buffer(b, (n * k,), "float32") + C_data = T.match_buffer(c, (m * k,), "float32") + J_indptr = T.match_buffer(indptr, (m + 1,), "int32") + J_indices = T.match_buffer(indices, (nnz,), "int32") + # body + # with T.block("root") + for v_vi, v_vk in T.grid(m, k): + with T.block("csrmm0"): + vi, vk = T.axis.remap("SS", [v_vi, v_vk]) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: nnz], B_data[0: n * k], C_data[0: m * k]) + T.writes(C_data[0: m * k]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): + with T.block("csrmm1"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: nnz], B_data[0: n * k], C_data[0: m * k]) + T.writes(C_data[0: m * k]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[vi * k + vk] = T.float32(0) + C_data[vi * k + vk] = C_data[vi * k + vk] + A_data[J_indptr[vi] + vj] * \ + B_data[J_indices[J_indptr[vi] + vj] * k + vk] + + +@T.prim_func +def lowered_csrmm_dense_iter(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, (nnz,), "float32") + B_data = T.match_buffer(b, (n * k,), "float32") + C_data = T.match_buffer(c, (m * k,), "float32") + J_indptr = T.match_buffer(indptr, (m + 1,), "int32") + J_indices = T.match_buffer(indices, (nnz,), "int32") + # body + # with T.block("root") + for v_vi, v_vj, v_vk in T.grid(m, n, k): + with T.block("csrmm0"): + vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk]) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: nnz], B_data[0: n * k], C_data[0: m * k]) + T.writes(C_data[0: m * k]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[vi * k + vk] = T.float32(0) + C_data[vi * k + vk] = C_data[vi * k + vk] + A_data[T.tvm_lower_bound( + J_indices.data, vj, J_indptr[vi], J_indptr[vi + 1], dtype="int32")] * B_data[vj * k + vk] + + +@T.prim_func +def lowered_csr_reduce(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, m: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, [nnz], dtype="float32") + B_data = T.match_buffer(b, [n], dtype="float32") + J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") + J_indices = T.match_buffer(indices, [nnz], dtype="int32") + for v_vi in T.serial(0, n): + with T.block("csr_reduce_outer"): + vi = T.axis.spatial(n, v_vi) + T.reads([J_indptr[0: n + 1], J_indices[0: nnz], A_data[0: nnz], B_data[0: n]]) + T.writes([B_data[0: n]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): + with T.block("csr_reduce"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads([J_indptr[0: n + 1], J_indices[0: nnz], A_data[0: nnz], B_data[0: n]]) + T.writes([B_data[0: n]]) + T.block_attr({"sparse": True}) + with T.init(): + B_data[vi] = T.float32(0) + B_data[vi] = B_data[vi] + A_data[J_indptr[vi] + vj] + + +@T.prim_func +def lowered_segment_reduce(a: T.handle, b: T.handle, indptr: T.handle, n: T.int32, nnz: T.int32) -> None: + A_data = T.match_buffer(a, (nnz,), "float32") + B_data = T.match_buffer(b, (n,), "float32") + J_indptr = T.match_buffer(indptr, (n + 1,), "int32") + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for v_vi in T.serial(n): + with T.block("segment_reduce0"): + vi = T.axis.spatial(n, v_vi) + T.reads(J_indptr[0: n + 1], A_data[0: nnz], B_data[0: n]) + T.writes(B_data[0: n]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): + with T.block("segment_reduce1"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads(J_indptr[0: n + 1], A_data[0: nnz], B_data[0: n]) + T.writes(B_data[0: n]) + T.block_attr({"sparse": True}) + with T.init(): + B_data[vi] = T.float32(0) + B_data[vi] = B_data[vi] + A_data[J_indptr[vi] + vj] + + +@T.prim_func +def lowered_bsrmm(a: T.handle, b: T.handle, c: T.handle, j_indptr: T.handle, j_indices: T.handle, nb: T.int32, mb: T.int32, nnzb: T.int32, blk: T.int32, feat_size: T.int32) -> None: + A_data = T.match_buffer(a, (nnzb * blk * blk,), "float32") + B_data = T.match_buffer(b, (mb * blk * feat_size,), "float32") + C_data = T.match_buffer(c, (nb * blk * feat_size,), "float32") + J_indptr = T.match_buffer(j_indptr, (nb + 1,), "int32") + J_indices = T.match_buffer(j_indices, (nnzb,), "int32") + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for v_vi, v_vbi, v_vbj, v_vf in T.grid(nb, blk, blk, feat_size): + with T.block("bsrmm0"): + vi, vbi, vbj, vf = T.axis.remap("SSRS", [v_vi, v_vbi, v_vbj, v_vf]) + T.reads(J_indptr[0: nb + 1], J_indices[0: nnzb], A_data[0: nnzb * blk * blk], + B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]) + T.writes(C_data[0: nb * blk * feat_size]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) + for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): + with T.block("bsrmm1"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads(J_indptr[0: nb + 1], J_indices[0: nnzb], A_data[0: nnzb * blk * blk], + B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]) + T.writes(C_data[0: nb * blk * feat_size]) + T.block_attr({"sparse": True}) + C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[( + (J_indptr[vi] + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[J_indptr[vi] + vj] * blk + vbj) * feat_size + vf] + + +@T.prim_func +def lowered_ellmm(a: T.handle, b: T.handle, c: T.handle, j_indices: T.handle, nb: T.int32, mb: T.int32, feat_size: T.int32, col: T.int32, blk: T.int32) -> None: + A_data = T.match_buffer(a, (nb * col * blk * blk,), "float32") + B_data = T.match_buffer(b, (mb * blk * feat_size,), "float32") + C_data = T.match_buffer(c, (nb * blk * feat_size,), "float32") + J_indices = T.match_buffer(j_indices, (nb * col,), "int32") + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for v_vi, v_vj, v_vbi, v_vbj, v_vf in T.grid(nb, col, blk, blk, feat_size): + with T.block("ellmm0"): + vi, vj, vbi, vbj, vf = T.axis.remap("SRSRS", [v_vi, v_vj, v_vbi, v_vbj, v_vf]) + T.reads(J_indices[0: nb * col], A_data[0: nb * col * blk * blk], + B_data[0: mb * blk * feat_size], C_data[0: nb * blk * feat_size]) + T.writes(C_data[0: nb * blk * feat_size]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[(vi * blk + vbi) * feat_size + vf] = T.float32(0) + C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[((vi * + col + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[vi * col + vj] * blk + vbj) * feat_size + vf] + + +@T.prim_func +def lowered_sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, (m * k,), "float32") + B_data = T.match_buffer(b, (n * k,), "float32") + C_data = T.match_buffer(c, (nnz,), "float32") + J_indptr = T.match_buffer(indptr, (m + 1,), "int32") + J_indices = T.match_buffer(indices, (nnz,), "int32") + for v_vi in T.serial(m): + with T.block("sddmm0"): + vi = T.axis.spatial(m, v_vi) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: m * k], B_data[0: n * k], C_data[0: nnz]) + T.writes(C_data[0: nnz]) + T.block_attr({"sparse": True}) + for v_vj, v_vk in T.grid(J_indptr[vi + 1] - J_indptr[vi], k): + with T.block("sddmm1"): + vj, vk = T.axis.remap("SR", [v_vj, v_vk]) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: m * k], B_data[0: n * k], C_data[0: nnz]) + T.writes(C_data[0: nnz]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[J_indptr[vi] + vj] = T.float32(0) + C_data[J_indptr[vi] + vj] = C_data[J_indptr[vi] + vj] + \ + A_data[vi * k + vk] * B_data[J_indices[J_indptr[vi] + vj] * k + vk] + + +# from tvm.script import tir as T +@T.prim_func +def lowered_sddmm_fuse(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, (m * k,), "float32") + B_data = T.match_buffer(b, (n * k,), "float32") + C_data = T.match_buffer(c, (nnz,), "float32") + J_indptr = T.match_buffer(indptr, (m + 1,), "int32") + J_indices = T.match_buffer(indices, (nnz,), "int32") + # body + # with T.block("root") + for v_vi, v_vj, v_vk in T.grid(1, nnz, k): + with T.block("sddmm0"): + vi, vj, vk = T.axis.remap("SSR", [v_vi, v_vj, v_vk]) + T.reads(J_indptr[0: m + 1], J_indices[0: nnz], + A_data[0: m * k], B_data[0: n * k], C_data[0: nnz]) + T.writes(C_data[0: nnz]) + T.block_attr({"sparse": True}) + with T.init(): + C_data[vj] = T.float32(0) + C_data[vj] = C_data[vj] + A_data[(T.tvm_upper_bound(J_indptr.data, vj, 0, + m + 1, dtype="int32") - 1) * k + vk] * B_data[J_indices[vj] * k + vk] + + +@T.prim_func +def lowered_bmm( + x: T.handle, + y: T.handle, + z: T.handle, + indptr_i: T.handle, + indptr_j: T.handle, + indptr_k: T.handle, + indptr_ij: T.handle, + indptr_jk: T.handle, + indptr_ik: T.handle, + batch_size: T.int32, + nnz_i: T.int32, + nnz_j: T.int32, + nnz_k: T.int32, + nnz_ij: T.int32, + nnz_jk: T.int32, + nnz_ik: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + X_data = T.match_buffer(x, (nnz_ij,), "float32") + Y_data = T.match_buffer(y, (nnz_jk,), "float32") + Z_data = T.match_buffer(z, (nnz_ik,), "float32") + I_indptr = T.match_buffer(indptr_i, (batch_size + 1,), "int32") + J_indptr = T.match_buffer(indptr_j, (batch_size + 1,), "int32") + K_indptr = T.match_buffer(indptr_k, (batch_size + 1,), "int32") + IJ_indptr = T.match_buffer(indptr_ij, (batch_size + 1,), "int32") + JK_indptr = T.match_buffer(indptr_jk, (batch_size + 1,), "int32") + IK_indptr = T.match_buffer(indptr_ik, (batch_size + 1,), "int32") + # body + # with T.block("root") + for v_vb in T.serial(batch_size): + with T.block("bmm0"): + vb = T.axis.spatial(batch_size, v_vb) + T.reads(I_indptr[0: batch_size + 1], J_indptr[0: batch_size + 1], K_indptr[0: batch_size + 1], IJ_indptr[0: batch_size + 1], + JK_indptr[0: batch_size + 1], IK_indptr[0: batch_size + 1], X_data[0: nnz_ij], Y_data[0: nnz_jk], Z_data[0: nnz_ik]) + T.writes(Z_data[0: nnz_ik]) + T.block_attr({"sparse": True}) + for v_vi, v_vj, v_vk in T.grid(I_indptr[vb + 1] - I_indptr[vb], J_indptr[vb + 1] - J_indptr[vb], K_indptr[vb + 1] - K_indptr[vb]): + with T.block("bmm1"): + vi, vj, vk = T.axis.remap("SRS", [v_vi, v_vj, v_vk]) + T.reads(I_indptr[0: batch_size + 1], J_indptr[0: batch_size + 1], K_indptr[0: batch_size + 1], IJ_indptr[0: batch_size + 1], + JK_indptr[0: batch_size + 1], IK_indptr[0: batch_size + 1], X_data[0: nnz_ij], Y_data[0: nnz_jk], Z_data[0: nnz_ik]) + T.writes(Z_data[0: nnz_ik]) + T.block_attr({"sparse": True}) + with T.init(): + Z_data[IK_indptr[vb] + vi * + (K_indptr[vb + 1] - K_indptr[vb]) + vk] = T.float32(0) + Z_data[IK_indptr[vb] + vi * (K_indptr[vb + 1] - K_indptr[vb]) + vk] = Z_data[IK_indptr[vb] + vi * (K_indptr[vb + 1] - K_indptr[vb]) + vk] + \ + X_data[IJ_indptr[vb] + vi * (J_indptr[vb + 1] - J_indptr[vb]) + vj] * \ + Y_data[JK_indptr[vb] + vj * (K_indptr[vb + 1] - K_indptr[vb]) + vk] + + +@T.prim_func +def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, [nnz_k], dtype="float32") + B_data = T.match_buffer(b, [M], dtype="float32") + J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32") + J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") + K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") + K_indices = T.match_buffer(indices_k, [nnz_k], dtype="int32") + + for v_vi in T.serial(0, M): + with T.block("square_sum_2"): + vi = T.axis.spatial(M, v_vi) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K_indptr[0: nnz_j + 1], + K_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): + with T.block("square_sum_1"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K_indptr[0: nnz_j + 1], + K_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + with T.init(): + B_data[vi] = T.float32(0) + for v_vk in T.serial(0, K_indptr[J_indptr[vi] + vj + 1] - K_indptr[J_indptr[vi] + vj]): + with T.block("square_sum"): + vk = T.axis.reduce( + K_indptr[J_indptr[vi] + vj + 1] - K_indptr[J_indptr[vi] + vj], v_vk) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K_indptr[0: nnz_j + 1], + K_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + B_data[vi] = B_data[vi] + A_data[K_indptr[J_indptr[vi] + vj] + vk] + + +@T.prim_func +def lowered_square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, [nnz_k], dtype="float32") + B_data = T.match_buffer(b, [M], dtype="float32") + J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32") + J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32") + K0_indptr = T.match_buffer(indptr_k0, [nnz_j + 1], dtype="int32") + K0_indices = T.match_buffer(indices_k0, [nnz_k], dtype="int32") + K1_indptr = T.match_buffer(indptr_k1, [nnz_j + 1], dtype="int32") + K1_indices = T.match_buffer(indices_k1, [nnz_k], dtype="int32") + + for v_vi in T.serial(0, M): + with T.block("square_sum_2"): + vi = T.axis.spatial(M, v_vi) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K0_indptr[0: nnz_j + 1], K0_indices[0: nnz_k], + K1_indptr[0: nnz_j + 1], K1_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): + with T.block("square_sum_1"): + vj = T.axis.reduce(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K0_indptr[0: nnz_j + 1], K0_indices[0: nnz_k], + K1_indptr[0: nnz_j + 1], K1_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + with T.init(): + B_data[vi] = T.float32(0) + for v_vk in T.serial(0, K1_indptr[J_indptr[vi] + vj + 1] - K1_indptr[J_indptr[vi] + vj]): + with T.block("square_sum"): + vk = T.axis.reduce( + K1_indptr[J_indptr[vi] + vj + 1] - K1_indptr[J_indptr[vi] + vj], v_vk) + T.reads([J_indptr[0: M + 1], J_indices[0: nnz_j], K0_indptr[0: nnz_j + 1], K0_indices[0: nnz_k], + K1_indptr[0: nnz_j + 1], K1_indices[0: nnz_k], A_data[0: nnz_k], B_data[0: M]]) + T.writes([B_data[0: M]]) + T.block_attr({"sparse": True}) + B_data[vi] = B_data[vi] + A_data[T.tvm_lower_bound( + K0_indices.data, K1_indices[K1_indptr[J_indptr[vi] + vj] + vk], K0_indptr[J_indptr[vi] + vj], K0_indptr[J_indptr[vi] + vj + 1], dtype="int32")] + + +@T.prim_func +def lowered_csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + A_data = T.match_buffer(a, [nnz], dtype="float32") + B_data = T.match_buffer(b, [nnz], dtype="float32") + J_indptr = T.match_buffer(indptr, [m + 1], dtype="int32") + J_indices = T.match_buffer(indices, [nnz], dtype="int32") + for v_vi in T.serial(0, m): + with T.block("csr_element_wise_outer"): + vi = T.axis.spatial(m, v_vi) + T.reads([J_indptr[0: m + 1], J_indices[0: nnz], A_data[0: nnz]]) + T.writes([B_data[0: nnz]]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(0, J_indptr[vi + 1] - J_indptr[vi]): + with T.block("csr_element_wise"): + vj = T.axis.spatial(J_indptr[vi + 1] - J_indptr[vi], v_vj) + T.reads([J_indptr[0: m + 1], J_indices[0: nnz], A_data[0: nnz]]) + T.writes([B_data[0: nnz]]) + T.block_attr({"sparse": True}) + B_data[J_indptr[vi] + vj] = A_data[J_indptr[vi] + vj] * T.float32(2.5) + + +@T.prim_func +def lowered_rgcn_forward(etype: T.handle, w: T.handle, x: T.handle, y: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, r: T.int32, feat_size: T.int32, nnz: T.int32) -> None: + E_data = T.match_buffer(etype, [nnz], dtype="int32") + W_data = T.match_buffer(w, [r * feat_size * feat_size], dtype="float32") + X_data = T.match_buffer(x, [n * feat_size], dtype="float32") + Y_data = T.match_buffer(y, [n * feat_size], dtype="float32") + J_indptr = T.match_buffer(indptr, [n + 1], dtype="int32") + J_indices = T.match_buffer(indices, [nnz], dtype="int32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for v_vi, v_vout in T.grid(n, feat_size): + with T.block("rgcn-forward_0"): + vi, vout = T.axis.remap("SS", [v_vi, v_vout]) + T.reads(J_indptr[0: n + 1], J_indices[0: nnz], E_data[0: nnz], W_data[0: r * + feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * feat_size]) + T.writes(Y_data[0: n * feat_size]) + T.block_attr({"sparse": True}) + for v_vj in T.serial(J_indptr[vi + 1] - J_indptr[vi]): + for v_vin in T.serial(feat_size): + with T.block("rgcn-forward_1"): + vj, vin = T.axis.remap("RR", [v_vj, v_vin]) + T.reads(J_indptr[0: n + 1], J_indices[0: nnz], E_data[0: nnz], W_data[0: r * + feat_size * feat_size], X_data[0: n * feat_size], Y_data[0: n * feat_size]) + T.writes(Y_data[0: n * feat_size]) + T.block_attr({"sparse": True}) + with T.init(): + Y_data[vi * feat_size + vout] = T.float32(0) + Y_data[vi * feat_size + vout] = Y_data[vi * feat_size + vout] + W_data[( + E_data[J_indptr[vi] + vj] * feat_size + vout) * feat_size + vin] * X_data[J_indices[J_indptr[vi] + vj] * feat_size + vin] + + +@T.prim_func +def lowered_fused_reduction_4d_2d(x: T.handle, y: T.handle, indptr_j: T.handle, indptr_k: T.handle, indptr_l: T.handle, n: T.int32, nnz_j: T.int32, nnz_k: T.int32, nnz_l: T.int32) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + X_data = T.match_buffer(x, [nnz_l], dtype="float32") + Y_data = T.match_buffer(y, [nnz_j], dtype="float32") + J_indptr = T.match_buffer(indptr_j, [n + 1], dtype="int32") + K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") + L_indptr = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32") + # body + # with T.block("root") + for v_vi, v_vj in T.grid(1, nnz_j): + with T.block("reduction_4d_2d0"): + vi, vj = T.axis.remap("SS", [v_vi, v_vj]) + T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1], + L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j]) + T.writes(Y_data[0: nnz_j]) + T.block_attr({"sparse": True}) + for v_vk in T.serial(K_indptr[vj + 1] - K_indptr[vj]): + with T.block("reduction_4d_2d1"): + vk = T.axis.reduce(K_indptr[vj + 1] - K_indptr[vj], v_vk) + T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1], + L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j]) + T.writes(Y_data[0: nnz_j]) + T.block_attr({"sparse": True}) + with T.init(): + Y_data[vj] = T.float32(0) + for v_vl in T.serial(L_indptr[K_indptr[vj] + vk + 1] - L_indptr[K_indptr[vj] + vk]): + with T.block("reduction_4d_2d2"): + vl = T.axis.reduce( + L_indptr[K_indptr[vj] + vk + 1] - L_indptr[K_indptr[vj] + vk], v_vl) + T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1], + L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_j]) + T.writes(Y_data[0: nnz_j]) + T.block_attr({"sparse": True}) + Y_data[vj] = Y_data[vj] + X_data[L_indptr[K_indptr[vj] + vk] + vl] + + +@T.prim_func +def lowered_fused_reduction_4d_3d(x: T.handle, y: T.handle, indptr_j: T.handle, indptr_k: T.handle, indptr_l: T.handle, n: T.int32, nnz_j: T.int32, nnz_k: T.int32, nnz_l: T.int32) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + X_data = T.match_buffer(x, [nnz_l], dtype="float32") + Y_data = T.match_buffer(y, [nnz_k], dtype="float32") + J_indptr = T.match_buffer(indptr_j, [n + 1], dtype="int32") + K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32") + L_indptr = T.match_buffer(indptr_l, [nnz_k + 1], dtype="int32") + # body + # with T.block("root") + for v_vi, v_vj, v_vk in T.grid(1, 1, nnz_k): + with T.block("reduction_4d_3d0"): + vi, vj, vk = T.axis.remap("SSS", [v_vi, v_vj, v_vk]) + T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1], + L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_k]) + T.writes(Y_data[0: nnz_k]) + T.block_attr({"sparse": True}) + for v_vl in T.serial(L_indptr[vk + 1] - L_indptr[vk]): + with T.block("reduction_4d_3d1"): + vl = T.axis.reduce(L_indptr[vk + 1] - L_indptr[vk], v_vl) + T.reads(J_indptr[0: n + 1], K_indptr[0: nnz_j + 1], + L_indptr[0: nnz_k + 1], X_data[0: nnz_l], Y_data[0: nnz_k]) + T.writes(Y_data[0: nnz_k]) + T.block_attr({"sparse": True}) + with T.init(): + Y_data[vk] = T.float32(0) + Y_data[vk] = Y_data[vk] + X_data[L_indptr[vk] + vl] diff --git a/tests/python/sparsetir/sparse_tir_scripts.py b/tests/python/sparsetir/sparse_tir_scripts.py new file mode 100644 index 000000000000..1c473026bc5d --- /dev/null +++ b/tests/python/sparsetir/sparse_tir_scripts.py @@ -0,0 +1,361 @@ +from tvm.script import tir as T + + +@T.prim_func +def csrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + m: T.int32, + n: T.int32, + k: T.int32, + nnz: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") + with T.iter([I, K, J], "SSR", "csrmm") as [vi, vk, vj]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + +@T.prim_func +def csrmm_dense_iter( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + m: T.int32, + n: T.int32, + k: T.int32, + nnz: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") + with T.iter([I, T.dense(J), K], "SRS", "csrmm") as [vi, vj, vk]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + +@T.prim_func +def segment_reduce( + a: T.handle, + b: T.handle, + indptr: T.handle, + n: T.int32, + nnz: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(n) + J = T.dense_variable(I, (100, nnz), indptr, "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + with T.iter([I, J], "SR", "segment_reduce") as [vi, vj]: + with T.init(): + B[vi] = 0. + B[vi] = B[vi] + A[vi, vj] + + +@T.prim_func +def csr_reduce( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + m: T.int32, + nnz: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(n) + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + with T.iter([I, J], "SR", "csr_reduce") as [vi, vj]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj] + + +@T.prim_func +def bsrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + nnzb: T.int32, + blk: T.int32, + feat_size: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(nb) + J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), "int32") + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + F = T.dense_fixed(feat_size) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") + + with T.iter([I, BI, BJ, F, J], "SSRSR", "bsrmm") as [ + vi, + vbi, + vbj, + vf, + vj, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def ellmm( + a: T.handle, + b: T.handle, + c: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + feat_size: T.int32, + col: T.int32, + blk: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(nb) + J = T.sparse_fixed(I, (mb, col), indices, "int32") + F = T.dense_fixed(feat_size) + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") + + with T.iter([I, J, BI, BJ, F], "SRSRS", "ellmm") as [ + vi, + vj, + vbi, + vbj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def csr_element_wise( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + m: T.int32, + n: T.int32, + nnz: T.int32, +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I, J), "float32") + + with T.iter([I, J], "SS", "csr_element_wise") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.5 + + +@T.prim_func +def bmm( + x: T.handle, + y: T.handle, + z: T.handle, + indptr_i: T.handle, + indptr_j: T.handle, + indptr_k: T.handle, + indptr_ij: T.handle, + indptr_jk: T.handle, + indptr_ik: T.handle, + batch_size: T.int32, + nnz_i: T.int32, + nnz_j: T.int32, + nnz_k: T.int32, + nnz_ij: T.int32, + nnz_jk: T.int32, + nnz_ik: T.int32 +) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + B = T.dense_fixed(batch_size) + I = T.dense_variable(B, (32768, nnz_i), indptr_i, "int32") + J = T.dense_variable(B, (32768, nnz_j), indptr_j, "int32") + K = T.dense_variable(B, (32768, nnz_k), indptr_k, "int32") + IJ = T.attach_axis(I, J, nnz_ij, indptr_ij, "int32") + JK = T.attach_axis(J, K, nnz_jk, indptr_jk, "int32") + IK = T.attach_axis(I, K, nnz_ik, indptr_ik, "int32") + X = T.match_sparse_buffer(x, (B, I, IJ), "float32") + Y = T.match_sparse_buffer(y, (B, J, JK), "float32") + Z = T.match_sparse_buffer(z, (B, I, IK), "float32") + with T.iter([B, I, J, K], "SSRS", "bmm") as [vb, vi, vj, vk]: + with T.init(): + Z[vb, vi, vk] = 0. + Z[vb, vi, vk] = Z[vb, vi, vk] + X[vb, vi, vj] * Y[vb, vj, vk] + + +@T.prim_func +def sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, K), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, J), "float32") + + with T.iter([I, J, K], "SSR", "sddmm") as [vi, vj, vk]: + with T.init(): + C[vi, vj] = 0. + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def fused_sddmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, m: T.int32, n: T.int32, k: T.int32, nnz: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, K), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, J), "float32") + + with T.iter([T.fuse(I, J), K], "SSR", "sddmm") as [vi, vj, vk]: + with T.init(): + C[vi, vj] = 0. + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +@T.prim_func +def square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k: T.handle, indices_k: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(M) + J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32") + K = T.sparse_variable(J, (N2, nnz_k), (indptr_k, indices_k), "int32") + A = T.match_sparse_buffer(a, (I, J, K), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + + with T.iter([I, J, K], "SRR", "square_sum") as [vi, vj, vk]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj, vk] + + +@T.prim_func +def square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32): + # Used only for testing `GetIndicesRange()`. + # Currently it is ensured that `indptr_k0` is the same as `indptr_k1`, and `indices_k0` is the + # same as `indices_k1`. + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(M) + J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32") + K0 = T.sparse_variable(J, (N2, nnz_k), (indptr_k0, indices_k0), "int32") + K1 = T.sparse_variable(J, (N2, nnz_k), (indptr_k1, indices_k1), "int32") + A = T.match_sparse_buffer(a, (I, J, K0), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + + with T.iter([I, J, K1], "SRR", "square_sum") as [vi, vj, vk]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj, vk] + + +@T.prim_func +def fused_reduction_4d_2d( + x: T.handle, + y: T.handle, + indptr_j: T.handle, + indptr_k: T.handle, + indptr_l: T.handle, + n: T.int32, + nnz_j: T.int32, + nnz_k: T.int32, + nnz_l: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(n) + J = T.dense_variable(I, (32768, nnz_j), indptr_j, "int32") + K = T.dense_variable(J, (32768, nnz_k), indptr_k, "int32") + L = T.dense_variable(K, (32768, nnz_l), indptr_l, "int32") + X = T.match_sparse_buffer(x, (I, J, K, L), "float32") + Y = T.match_sparse_buffer(y, (I, J), "float32") + with T.iter([T.fuse(I, J), K, L], "SSRR", "reduction_4d_2d") as [vi, vj, vk, vl]: + with T.init(): + Y[vi, vj] = 0.0 + Y[vi, vj] = Y[vi, vj] + X[vi, vj, vk, vl] + + +@T.prim_func +def fused_reduction_4d_3d( + x: T.handle, + y: T.handle, + indptr_j: T.handle, + indptr_k: T.handle, + indptr_l: T.handle, + n: T.int32, + nnz_j: T.int32, + nnz_k: T.int32, + nnz_l: T.int32) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + I = T.dense_fixed(n) + J = T.dense_variable(I, (32768, nnz_j), indptr_j, "int32") + K = T.dense_variable(J, (32768, nnz_k), indptr_k, "int32") + L = T.dense_variable(K, (32768, nnz_l), indptr_l, "int32") + X = T.match_sparse_buffer(x, (I, J, K, L), "float32") + Y = T.match_sparse_buffer(y, (I, J, K), "float32") + with T.iter([T.fuse(I, J, K), L], "SSSR", "reduction_4d_3d") as [vi, vj, vk, vl]: + with T.init(): + Y[vi, vj, vk] = 0.0 + Y[vi, vj, vk] = Y[vi, vj, vk] + X[vi, vj, vk, vl] + + +@T.prim_func +def rgcn_forward( + etype: T.handle, + w: T.handle, + x: T.handle, + y: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + r: T.int32, + feat_size: T.int32, + nnz: T.int32 +): + I = T.dense_fixed(n) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + R = T.dense_fixed(r) + F_in = T.dense_fixed(feat_size) + F_out = T.dense_fixed(feat_size) + E = T.match_sparse_buffer(etype, (I, J), "int32") + W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32") + X = T.match_sparse_buffer(x, (T.dense(J), F_in), "float32") + Y = T.match_sparse_buffer(y, (I, F_out), "float32") + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + with T.iter([I, F_out, J, F_in], "SSRR", "rgcn-forward") as [ + vi, vout, vj, vin, + ]: + with T.init(): + Y[vi, vout] = 0. + Y[vi, vout] = Y[vi, vout] + W[E[vi, vj], vout, vin] * X[vj, vin] diff --git a/tests/python/sparsetir/test_butterfly.py b/tests/python/sparsetir/test_butterfly.py new file mode 100644 index 000000000000..67c7f86395f7 --- /dev/null +++ b/tests/python/sparsetir/test_butterfly.py @@ -0,0 +1,38 @@ +import tvm +import tvm.testing +from tvm.runtime.ndarray import device +import tvm.tir as tir +import scipy.sparse as sp +import numpy as np +from tvm.script import tir as T + + +@T.prim_func +def butterfly(w1: T.handle, w2: T.handle, w3: T.handle, w4: T.handle, x: T.handle, y: T.handle) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + W1 = T.match_buffer(w1, (16, 2), "float32") + W2 = T.match_buffer(w2, (16, 2), "float32") + W3 = T.match_buffer(w3, (16, 2), "float32") + W4 = T.match_buffer(w4, (16, 2), "float32") + X = T.match_buffer(x, (16, 64), "float32") + Y = T.match_buffer(y, (16, 64), "float32") + + for i, j, k in T.grid(16, 2, 64): + with T.block("wx"): + vi, vj, vk = T.axis.remap("SRS", [i, j, k]) + with T.init(): + Y[vi, vk] = 0. + Y[vi, vk] = Y[vi, vk] +\ + W1[vi, vj] * X[vj * 8 + T.floormod(vi, 8), vk] +\ + W2[vi, vj] * X[T.floordiv(vi, 8) * 8 + vj * 4 + T.floormod(vi, 4), vk] +\ + W3[vi, vj] * X[T.floordiv(vi, 4) * 4 + vj * 2 + T.floormod(vi, 2), vk] +\ + W4[vi, vj] * X[T.floordiv(vi, 2) * 2 + vj, vk] + + +def test_butterfly(): + sch = tir.Schedule(butterfly) + print(sch.mod["main"].script()) + + +if __name__ == "__main__": + test_butterfly() diff --git a/tests/python/sparsetir/test_tir_sparse_buffer.py b/tests/python/sparsetir/test_tir_sparse_buffer.py new file mode 100644 index 000000000000..a3a099ff1f44 --- /dev/null +++ b/tests/python/sparsetir/test_tir_sparse_buffer.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.tir as tir + +def test_axis_creation(): + i = tir.sparse.DenseFixedAxis('i', 128) + j = tir.sparse.DenseFixedAxis('j', 128) + k = tir.sparse.DenseFixedAxis('k', 128) + print(i, j, k) + + +if __name__ == "__main__": + test_axis_creation() diff --git a/tests/python/sparsetir/test_tir_sparse_correctness.py b/tests/python/sparsetir/test_tir_sparse_correctness.py new file mode 100644 index 000000000000..506bc2998248 --- /dev/null +++ b/tests/python/sparsetir/test_tir_sparse_correctness.py @@ -0,0 +1,403 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +import tvm.tir as tir +import scipy.sparse as sp +import numpy as np +from tvm.script import tir as T +from lowered_tir import * + + +def test_csrmm(): + A = sp.random(512, 512, dtype="float32", density=0.0125, format="csr") + x = np.random.rand(512, 128).astype("float32") + y_ground_truth = A * x + y = np.zeros((512, 128)).astype("float32") + + n, m, k, nnz = lowered_csrmm.params[-4:] + f = tvm.build(lowered_csrmm.specialize({n: 512, m: 512, k: 128, nnz: A.nnz}), target="llvm") + + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) + X_nd = tvm.nd.array(x.reshape(-1), device=ctx) + Y_nd = tvm.nd.array(y.reshape(-1), device=ctx) + f(A_data, X_nd, Y_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_csr_reduce(): + A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") + b_ground_truth = np.array(np.sum(A, axis=1)) + b = np.zeros((128,)).astype("float32") + + n, m, nnz = lowered_csr_reduce.params[-3:] + f = tvm.build(lowered_csr_reduce.specialize({n: 128, m: 128, nnz: A.nnz}), target="llvm") + + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) + B_nd = tvm.nd.array(b, device=ctx) + f(A_data, B_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(b_ground_truth.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_csr_element_wise(): + A = sp.random(128, 128, dtype="float32", density=0.0125, format="csr") + b_ground_truth = A * 2.5 + b = np.zeros((A.nnz,)).astype("float32") + + m, n, nnz = lowered_csr_element_wise.params[-3:] + f = tvm.build(lowered_csr_element_wise.specialize({m: 128, n: 128, nnz: A.nnz}), target="llvm") + + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(A.indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(A.data.astype("float32"), device=ctx) + B_nd = tvm.nd.array(b, device=ctx) + f(A_data, B_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(b_ground_truth.data.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_bsrmm(): + block_size = 16 + nb = 32 + mb = 32 + feat_size = 256 + n = nb * block_size + m = mb * block_size + + A_block = sp.random(mb, nb, dtype="float32", density=0.05, format="csr") + indptr = A_block.indptr + indices = A_block.indices + nnzb = A_block.nnz + data = np.random.rand(nnzb, block_size, block_size) + A = sp.bsr_matrix((data, indices, indptr), shape=(n, m)) + x = np.random.rand(m, feat_size).astype("float32") + y_ground_truth = A * x + y = np.zeros((n * feat_size,)).astype("float32") + + v_nb, v_mb, v_nnzb, v_blk, v_feat_size = lowered_bsrmm.params[-5:] + f = tvm.build( + lowered_bsrmm.specialize( + {v_nb: nb, v_mb: mb, v_nnzb: nnzb, v_blk: block_size, v_feat_size: feat_size} + ), + target="llvm", + ) + + ctx = tvm.cpu(0) + A_indptr = tvm.nd.array(indptr.astype("int32"), device=ctx) + A_indices = tvm.nd.array(indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(data.reshape(-1).astype("float32"), device=ctx) + X_nd = tvm.nd.array(x.reshape(-1), device=ctx) + Y_nd = tvm.nd.array(y, device=ctx) + f(A_data, X_nd, Y_nd, A_indptr, A_indices) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_ellmm(): + nnz_cols = 4 + nb = 64 + mb = 64 + feat_size = 1024 + nnz = nb * nnz_cols + block_size = 16 + n = nb * block_size + m = mb * block_size + + rng = np.random.default_rng() + indptr = np.arange(0, (nb + 1) * nnz_cols, nnz_cols) + indices = np.array([rng.choice(mb, size=nnz_cols, replace=False) for i in range(nb)]) + order = indices.argsort(axis=1) + indices = np.array([indices[i, order[i]] for i in range(0, nb)]).reshape(-1) + data = np.random.rand(nnz, block_size, block_size) + A = sp.bsr_matrix((data, indices, indptr), shape=(n, m)) + x = np.random.rand(m, feat_size).astype("float32") + y_ground_truth = A * x + y = np.zeros((n * feat_size,)).astype("float32") + + v_nb, v_mb, v_feat_size, v_col, v_blk = lowered_ellmm.params[-5:] + f = tvm.build( + lowered_ellmm.specialize( + { + v_nb: nb, + v_mb: mb, + v_feat_size: feat_size, + v_col: nnz_cols, + v_blk: block_size, + } + ), + target="llvm", + ) + + ctx = tvm.cpu(0) + A_indices = tvm.nd.array(indices.astype("int32"), device=ctx) + A_data = tvm.nd.array(data.reshape(-1).astype("float32"), device=ctx) + X_nd = tvm.nd.array(x.reshape(-1), device=ctx) + Y_nd = tvm.nd.array(y, device=ctx) + f(A_data, X_nd, Y_nd, A_indices) + tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) + + +def test_sddmm(): + # generate random input + m = 4096 + n = 4096 + k = 256 + C = sp.random(m, n, dtype="float32", density=0.0125, format='csr') + indptr = C.indptr + indices = C.indices + C_coo = C.tocoo() + nnz = C.nnz + x = np.random.rand(m, k).astype("float32") + y = np.random.rand(n, k).astype("float32") + z_ground_truth = np.matmul(x, y.transpose())[C_coo.row, C_coo.col] + z = np.zeros((nnz,)).astype("float32") + + # specialize function + _, _, _, _, _, M, N, K, NNZ = lowered_sddmm.params + sch = tir.Schedule( + lowered_sddmm.specialize( + {M: m, N: n, K: k, NNZ: nnz} + ) + ) + blk_outer = sch.get_block("sddmm0") + blk_inner = sch.get_block("sddmm1") + i, = sch.get_loops(blk_outer) + _, k = sch.get_loops(blk_inner) + sch.bind(i, "blockIdx.x") + sch.bind(k, "threadIdx.x") + + # convert numpy tensor to tvm ndarray + C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0)) + C_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0)) + X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0)) + Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0)) + C_data = tvm.nd.array(z, device=tvm.cuda(0)) + + # build function + f = tvm.build(sch.mod['main'], target="cuda") + f(X_nd, Y_nd, C_data, C_indptr, C_indices) + + # assertion + tvm.testing.assert_allclose(z_ground_truth, C_data.numpy(), rtol=1e-5) + + +def test_sddmm_fuse(): + # generate random input + m = 4096 + n = 4096 + k = 256 + C = sp.random(m, n, dtype="float32", density=0.0125, format='csr') + indptr = C.indptr + indices = C.indices + C_coo = C.tocoo() + nnz = C.nnz + x = np.random.rand(m, k).astype("float32") + y = np.random.rand(n, k).astype("float32") + z_ground_truth = np.matmul(x, y.transpose())[C_coo.row, C_coo.col] + z = np.zeros((nnz,)).astype("float32") + + # specialize function + _, _, _, _, _, M, N, K, NNZ = lowered_sddmm_fuse.params + sch = tir.Schedule( + lowered_sddmm_fuse.specialize( + {M: m, N: n, K: k, NNZ: nnz} + ) + ) + blk = sch.get_block("sddmm0") + i, j, k = sch.get_loops(blk) + sch.unroll(i) + sch.bind(j, "blockIdx.x") + sch.bind(k, "threadIdx.x") + + # convert numpy tensor to tvm ndarray + C_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0)) + C_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0)) + X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0)) + Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0)) + C_data = tvm.nd.array(z, device=tvm.cuda(0)) + + # build function + f = tvm.build(sch.mod['main'], target="cuda") + f(X_nd, Y_nd, C_data, C_indptr, C_indices) + + # assertion + tvm.testing.assert_allclose(z_ground_truth, C_data.numpy(), rtol=1e-5) + + +def test_bmm(): + # generate random input + batch_size = 32 + n_arr = np.random.randint(128, 1024, size=(batch_size,)).astype("int32") + m_arr = np.random.randint(128, 1024, size=(batch_size,)).astype("int32") + k_arr = np.random.randint(128, 1024, size=(batch_size,)).astype("int32") + nm_arr = n_arr * m_arr + mk_arr = m_arr * k_arr + nk_arr = n_arr * k_arr + indptr_n = np.concatenate(([0], n_arr)).cumsum() + indptr_m = np.concatenate(([0], m_arr)).cumsum() + indptr_k = np.concatenate(([0], k_arr)).cumsum() + indptr_nm = np.concatenate(([0], nm_arr)).cumsum() + indptr_mk = np.concatenate(([0], mk_arr)).cumsum() + indptr_nk = np.concatenate(([0], nk_arr)).cumsum() + nnz_i = indptr_n[-1] + nnz_j = indptr_m[-1] + nnz_k = indptr_k[-1] + nnz_ij = indptr_nm[-1] + nnz_jk = indptr_mk[-1] + nnz_ik = indptr_nk[-1] + As = [ + np.random.rand(n, m).astype("float32") for n, m in zip(n_arr, m_arr) + ] + Bs = [ + np.random.rand(m, k).astype("float32") for m, k in zip(m_arr, k_arr) + ] + Cs = [ + np.matmul(A, B) for A, B in zip(As, Bs) + ] + A_flatten = np.concatenate([A.flatten() for A in As], 0) + B_flatten = np.concatenate([B.flatten() for B in Bs], 0) + c_flatten = np.concatenate([C.flatten() for C in Cs], 0) + + # specialize function + _, _, _, _, _, _, _, _, _, BATCH, NNZ_I, NNZ_J, NNZ_K, NNZ_IJ, NNZ_JK, NNZ_IK = lowered_bmm.params + sch = tir.Schedule( + lowered_bmm.specialize({ + BATCH: batch_size, NNZ_I: nnz_i, NNZ_J: nnz_j, NNZ_K: nnz_k, NNZ_IJ: nnz_ij, NNZ_JK: nnz_jk, NNZ_IK: nnz_ik + }) + ) + bmm_outer = sch.get_block("bmm0") + b, = sch.get_loops(bmm_outer) + bmm_inner = sch.get_block("bmm1") + i, j, k = sch.get_loops(bmm_inner) + sch.reorder(i, k, j) + io, ii = sch.split(i, [None, 32]) + ko, ki = sch.split(k, [None, 32]) + sch.bind(b, "blockIdx.x") + sch.bind(ki, "threadIdx.x") + sch.bind(ii, "threadIdx.y") + sch.decompose_reduction(bmm_inner, j) + + # convert numpy tensor to tvm ndarray + dev = tvm.cuda(0) + A_nd = tvm.nd.array(A_flatten, device=dev) + B_nd = tvm.nd.array(B_flatten, device=dev) + C_nd = tvm.nd.array(np.zeros_like(c_flatten), device=dev) + indptr_n_nd = tvm.nd.array(indptr_n.astype("int32"), device=dev) + indptr_m_nd = tvm.nd.array(indptr_m.astype("int32"), device=dev) + indptr_k_nd = tvm.nd.array(indptr_k.astype("int32"), device=dev) + indptr_nm_nd = tvm.nd.array(indptr_nm.astype("int32"), device=dev) + indptr_mk_nd = tvm.nd.array(indptr_mk.astype("int32"), device=dev) + indptr_nk_nd = tvm.nd.array(indptr_nk.astype("int32"), device=dev) + + # build function + f = tvm.build(sch.mod["main"], target="cuda") + f(A_nd, B_nd, C_nd, indptr_n_nd, indptr_m_nd, indptr_k_nd, indptr_nm_nd, indptr_mk_nd, indptr_nk_nd) + + # assertion + tvm.testing.assert_allclose(C_nd.numpy(), c_flatten, rtol=1e-5) + + +def test_square_sum(): + density = 0.0125 + M = N1 = N2 = 128 + A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr") + indptr_j = A_J.indptr + indices_j = A_J.indices + nnz_j = A_J.nnz + A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr") + indptr_k = A_K.indptr + indices_k = A_K.indices + nnz_k = A_K.nnz + data = A_K.data + + b_ij = np.asarray(A_K.sum(axis=1)).squeeze() + A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1)) + b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze() + b = np.zeros((M,)).astype("float32") + + v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = lowered_square_sum.params[-5:] + f = tvm.build(lowered_square_sum.specialize( + {v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="llvm") + + ctx = tvm.cpu(0) + A_data = tvm.nd.array(data.astype("float32"), device=ctx) + A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx) + A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx) + A_indptr_k = tvm.nd.array(indptr_k.astype("int32"), device=ctx) + A_indices_k = tvm.nd.array(indices_k.astype("int32"), device=ctx) + B_data = tvm.nd.array(b.astype("float32"), device=ctx) + f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k, A_indices_k) + + tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5) + + +def test_square_sum_two_K(): + sch = tir.Schedule(lowered_square_sum_two_K, debug_mask="all") + i, = sch.get_loops(sch.get_block("square_sum_2")) + sch.bind(i, "threadIdx.x") + + density = 0.0125 + M = N1 = N2 = 128 + A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr") + indptr_j = A_J.indptr + indices_j = A_J.indices + nnz_j = A_J.nnz + A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr") + indptr_k = A_K.indptr + indices_k = A_K.indices + nnz_k = A_K.nnz + data = A_K.data + + b_ij = np.asarray(A_K.sum(axis=1)).squeeze() + A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1)) + b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze() + b = np.zeros((M,)).astype("float32") + + v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = sch.mod["main"].params[-5:] + f = tvm.build(sch.mod["main"].specialize( + {v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="cuda") + + ctx = tvm.device("cuda") + A_data = tvm.nd.array(data.astype("float32"), device=ctx) + A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx) + A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx) + A_indptr_k0 = tvm.nd.array(indptr_k.astype("int32"), device=ctx) + A_indices_k0 = tvm.nd.array(indices_k.astype("int32"), device=ctx) + A_indptr_k1 = tvm.nd.array(indptr_k.astype("int32"), device=ctx) + A_indices_k1 = tvm.nd.array(indices_k.astype("int32"), device=ctx) + B_data = tvm.nd.array(b.astype("float32"), device=ctx) + f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k0, A_indices_k0, A_indptr_k1, A_indices_k1) + + tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5) + + +if __name__ == "__main__": + test_csrmm() + test_csr_reduce() + test_csr_element_wise() + test_bsrmm() + test_ellmm() + test_sddmm() + test_sddmm_fuse() + test_bmm() + test_square_sum() + test_square_sum_two_K() diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py new file mode 100644 index 000000000000..3a7906451e68 --- /dev/null +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +import pytest +from lowered_tir import * +from sparse_tir_scripts import * + + +def test_csrmm(): + mod = tvm.IRModule.from_expr(csrmm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) + + +def test_csrmm_dense_iter(): + mod = tvm.IRModule.from_expr(csrmm_dense_iter) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm_dense_iter, True) + + +def test_segment_reduce(): + mod = tvm.IRModule.from_expr(segment_reduce) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_segment_reduce, True) + + +def test_csr_reduce(): + mod = tvm.IRModule.from_expr(csr_reduce) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_csr_reduce, True) + + +def test_bsrmm(): + mod = tvm.IRModule.from_expr(bsrmm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_bsrmm, True) + + +def test_ellpack_mm(): + mod = tvm.IRModule.from_expr(ellmm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_ellmm, True) + + +def test_csr_element_wise(): + mod = tvm.IRModule.from_expr(csr_element_wise) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_csr_element_wise, True) + + +def test_bmm(): + mod = tvm.IRModule.from_expr(bmm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_bmm) + + +def test_sddmm(): + mod = tvm.IRModule.from_expr(sddmm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_sddmm) + + +def test_fused_sddmm(): + mod = tvm.IRModule.from_expr(fused_sddmm) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_sddmm_fuse) + + +def test_square_sum(): + mod = tvm.IRModule.from_expr(square_sum) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum, True) + + +def test_square_sum_two_K(): + mod = tvm.IRModule.from_expr(square_sum_two_K) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum_two_K, True) + + +def test_fused_reduction(): + mod = tvm.IRModule.from_expr(fused_reduction_4d_2d) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_fused_reduction_4d_2d, True) + + mod = tvm.IRModule.from_expr(fused_reduction_4d_3d) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + tvm.ir.assert_structural_equal(mod["main"], lowered_fused_reduction_4d_3d, True) + + +if __name__ == "__main__": + test_csrmm() + test_csrmm_dense_iter() + test_segment_reduce() + test_csr_reduce() + test_bsrmm() + test_ellpack_mm() + test_csr_element_wise() + test_sddmm() + test_fused_sddmm() + test_bmm() + test_square_sum() + test_square_sum_two_K() + test_fused_reduction() diff --git a/tests/python/sparsetir/test_tir_sparse_nnz_inference.py b/tests/python/sparsetir/test_tir_sparse_nnz_inference.py new file mode 100644 index 000000000000..f521503775ca --- /dev/null +++ b/tests/python/sparsetir/test_tir_sparse_nnz_inference.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.tir as tir +import scipy.sparse as sp +import numpy as np +from tvm.script import tir as T + +@T.prim_func +def csr2bsr_cnt_nnz( + indptr: T.handle, indices: T.handle, + new_cord: T.handle, glb_counter: T.handle, + n: T.int32, m: T.int32, nnz: T.int32) -> None: + I = T.dense_fixed(n) + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") + K = T.dense_fixed(2) + New_cord = T.match_sparse_buffer(new_cord, (I, J, K), "int32") + with T.iter([I, J], "SS", "csr2bsr_cnt_nnz") as [vi, vj]: + New_cord[vi, vj, 0] = 0 + New_cord[vi, vj, 1] = 1 + + +@T.prim_func +def csr2bsr(indptr_1: T.handle, indices_1: T.handle, indptr_2: T.handle, indices_2: T.handle, + a_csr: T.handle, a_bsr: T.handle, + block_size: T.int32, + n: T.int32, m: T.int32, nnz: T.int32, + nb: T.int32, mb: T.int32, nnzb: T.int32) -> None: + I = T.dense_fixed(n) + J = T.sparse_variable(I, (m, nnz), (indptr_1, indices_1), "int32") + Ibo = T.dense_fixed(nb) + Jbo = T.sparse_variable(Ibo, (mb, nnzb), (indptr_2, indices_2), "int32") + Ibi = T.dense_fixed(block_size) + Jbi = T.dense_fixed(block_size) + A_csr = T.match_sparse_buffer(a_csr, (I, J), "float32") + A_bsr = T.match_sparse_buffer(a_bsr, (Ibo, Jbo, Ibi, Jbi), "float32") + with T.iter([I, J], "SS", "csr2bsrm") as [vi, vj]: + A_bsr[T.floordiv(vi, block_size), T.floordiv(vj, block_size), T.floormod(vi, block_size), T.floormod(vj, block_size)] =\ + A_csr[vi, vj] + + +def test_cnt_nnz(): + mod = tvm.IRModule.from_expr(csr2bsr_cnt_nnz) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + print(mod['main'].script()) + + +def test_csr2bsr(): + mod = tvm.IRModule.from_expr(csr2bsr) + mod = tvm.tir.transform.LowerSparseTIR()(mod) + print(mod['main'].script()) + + +if __name__ == "__main__": + test_cnt_nnz() + test_csr2bsr() \ No newline at end of file diff --git a/tests/python/sparsetir/test_tir_sparse_schedule.py b/tests/python/sparsetir/test_tir_sparse_schedule.py new file mode 100644 index 000000000000..7be256a2c1df --- /dev/null +++ b/tests/python/sparsetir/test_tir_sparse_schedule.py @@ -0,0 +1,238 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +import tvm.tir as tir +import scipy.sparse as sp +import numpy as np +from tvm.script import tir as T +from scipy.sparse import bsr +import pytest + + +@T.prim_func +def csrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + m: T.int32, + k: T.int32, + nnz: T.int32, +) -> None: + I = T.dense_fixed(n) + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") + with T.iter([I, K, J], "SSR", "csrmm") as [vi, vk, vj]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + +@T.prim_func +def csr_reduce( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + m: T.int32, + nnz: T.int32, +) -> None: + I = T.dense_fixed(n) + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + with T.iter([I, J], "SR", "csr_reduce") as [vi, vj]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj] + + +@T.prim_func +def bsrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + nnzb: T.int32, + blk: T.int32, + feat_size: T.int32, +) -> None: + I = T.dense_fixed(nb) + J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), "int32") + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + F = T.dense_fixed(feat_size) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") + + with T.iter([I, J, BI, BJ, F], "SRSRS", "bsrmm") as [ + vi, + vj, + vbi, + vbj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def ellpack_mm( + a: T.handle, + b: T.handle, + c: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + feat_size: T.int32, + col: T.int32, + blk: T.int32, +) -> None: + I = T.dense_fixed(nb) + J = T.sparse_fixed(I, (mb, col), indices, "int32") + F = T.dense_fixed(feat_size) + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") + + with T.iter([I, J, BI, BJ, F], "SRSRS", "ellmm") as [ + vi, + vj, + vbi, + vbj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def csr_element_wise( + a: T.handle, + b: T.handle, + indptr: T.handle, + indices: T.handle, + m: T.int32, + n: T.int32, + nnz: T.int32, +) -> None: + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I, J), "float32") + + with T.iter([I, J], "SS", "csr_element_wise") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.5 + + +@T.prim_func +def reordered_bsrmm( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + nb: T.int32, + mb: T.int32, + nnzb: T.int32, + blk: T.int32, + feat_size: T.int32, +) -> None: + I = T.dense_fixed(nb) + J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), "int32") + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + F = T.dense_fixed(feat_size) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") + + with T.iter([BI, BJ, I, J, F], "SRSRS", "bsrmm") as [ + vbi, + vbj, + vi, + vj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +def test_get_sparse_block(): + sch = tir.Schedule(csrmm, debug_mask="all") + block_rv = sch.get_sparse_block("csrmm") + block = sch.get(block_rv) + assert block.name == "csrmm" + assert block.same_as(csrmm.body) + + +def test_get_sp_iters(): + sch = tir.Schedule(csrmm, debug_mask="all") + block = sch.get_sparse_block("csrmm") + vi, vj, vk = sch.get_sp_iters(block) + assert vi.same_as(csrmm.body.sp_iter_vars[0]) + assert vj.same_as(csrmm.body.sp_iter_vars[1]) + assert vk.same_as(csrmm.body.sp_iter_vars[2]) + + +def test_reorder(): + sch = tir.Schedule(bsrmm, debug_mask="all") + block = sch.get_sparse_block("bsrmm") + i, j, bi, bj, f = sch.get_sp_iters(block) + sch.sparse_reorder(block, [bi, bj, i, j, f]) + tvm.ir.assert_structural_equal(sch.mod["main"], reordered_bsrmm, True) + assert sch.get(block).name == "bsrmm" + + +def test_reorder_fail_on_dependency(): + sch = tir.Schedule(bsrmm, debug_mask="all") + block = sch.get_sparse_block("bsrmm") + i, j, bi, bj, f = sch.get_sp_iters(block) + with pytest.raises(tvm.tir.ScheduleError): + sch.sparse_reorder(block, [bi, bj, j, i, f]) + + +def test_reorder_fail_on_new_order_length(): + sch = tir.Schedule(bsrmm, debug_mask="all") + block = sch.get_sparse_block("bsrmm") + i, j, bi, bj, f = sch.get_sp_iters(block) + with pytest.raises(tvm.tir.ScheduleError): + sch.sparse_reorder(block, [bi, bj, i, j]) + + +if __name__ == "__main__": + test_get_sparse_block() + test_get_sp_iters() + test_reorder() + test_reorder_fail_on_dependency() + test_reorder_fail_on_new_order_length() diff --git a/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py b/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py new file mode 100644 index 000000000000..aedc2c22ec3c --- /dev/null +++ b/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.tir as tir +import tvm.te as te +from tvm.script import tir as T + + +@T.prim_func +def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: + n = T.var("int32") + m = T.var("int32") + k = T.var("int32") + nnz = T.var("int32") + I = T.dense_fixed(n) + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") + with T.iter([I, J, K], "SRS", "csrmm") as [vi, vj, vk]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + +@T.prim_func +def csrmm_dense_iter(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: + n = T.var("int32") + m = T.var("int32") + k = T.var("int32") + nnz = T.var("int32") + I = T.dense_fixed(n) + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), "float32") + C = T.match_sparse_buffer(c, (I, K), "float32") + with T.iter([I, T.dense(J), K], "SRS", "csrmm") as [vi, vj, vk]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + +@T.prim_func +def csr_reduce(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle) -> None: + n = T.var("int32") + m = T.var("int32") + nnz = T.var("int32") + I = T.dense_fixed(n) + J = T.sparse_variable(I, (m, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I,), "float32") + with T.iter([I, J], "SR", "csr_reduce") as [vi, vj]: + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vj] + + +@T.prim_func +def bsrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: + nb = T.var("int32") + mb = T.var("int32") + nnzb = T.var("int32") + blk = T.var("int32") + feat_size = T.var("int32") + I = T.dense_fixed(nb) + J = T.sparse_variable(I, (mb, nnzb), (indptr, indices), "int32") + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + F = T.dense_fixed(feat_size) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") + + with T.iter([I, J, BI, BJ, F], "SRSSS", "bsrmm") as [ + vi, + vj, + vbi, + vbj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def ellpack_mm(a: T.handle, b: T.handle, c: T.handle, indices: T.handle) -> None: + nb = T.var("int32") + mb = T.var("int32") + feat_size = T.var("int32") + col = T.var("int32") + blk = T.var("int32") + I = T.dense_fixed(nb) + J = T.sparse_fixed(I, (mb, col), indices, "int32") + F = T.dense_fixed(feat_size) + BI = T.dense_fixed(blk) + BJ = T.dense_fixed(blk) + A = T.match_sparse_buffer(a, (I, J, BI, BJ), "float32") + B = T.match_sparse_buffer(b, (T.dense(J), BJ, F), "float32") + C = T.match_sparse_buffer(c, (I, BI, F), "float32") + + with T.iter([I, J, BI, BJ, F], "SRSSS", "bsrmm") as [ + vi, + vj, + vbi, + vbj, + vf, + ]: + with T.init(): + C[vi, vbi, vf] = 0.0 + C[vi, vbi, vf] = C[vi, vbi, vf] + A[vi, vj, vbi, vbj] * B[vj, vbj, vf] + + +@T.prim_func +def csr_element_wise(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle): + m = T.var("int32") + n = T.var("int32") + nnz = T.var("int32") + I = T.dense_fixed(m) + J = T.sparse_variable(I, (n, nnz), (indptr, indices), "int32") + A = T.match_sparse_buffer(a, (I, J), "float32") + B = T.match_sparse_buffer(b, (I, J), "float32") + + with T.iter([I, J], "SS", "csr_element_wise") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + +def test_csrmm(): + func = csrmm + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_csrmm_dense_iter(): + func = csrmm_dense_iter + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_csr_reduce(): + func = csr_reduce + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_bsrmm(): + func = bsrmm + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_ellpack_mm(): + func = ellpack_mm + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +def test_csr_element_wise(): + func = csr_element_wise + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + +if __name__ == "__main__": + test_csrmm() + test_csrmm_dense_iter() + test_csr_reduce() + test_bsrmm() + test_ellpack_mm() + test_csr_element_wise() diff --git a/tests/python/sparsetir/test_tir_sparse_tensorize.py b/tests/python/sparsetir/test_tir_sparse_tensorize.py new file mode 100644 index 000000000000..f87f5c14cbbd --- /dev/null +++ b/tests/python/sparsetir/test_tir_sparse_tensorize.py @@ -0,0 +1 @@ +# TODO \ No newline at end of file diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 3e9e7fd33fd9..7a26b509c855 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -253,11 +253,95 @@ def test_fma(): assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin" +@T.prim_func +def binary_search(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + n = T.var('int32') + m = T.var('int32') + A = T.match_buffer(a, (n,), dtype='int32') + B = T.match_buffer(b, (m,), dtype='int32') + C = T.match_buffer(c, (m,), dtype='int32') + D = T.match_buffer(d, (m,), dtype='int32') + for i in T.serial(0, m): + with T.block('search'): + vi = T.axis.S(m, i) + T.reads([A[0:n], B[vi]]) + T.writes([C[vi], D[vi]]) + C[vi] = T.lower_bound(A.data, B[vi], 0, n) + D[vi] = T.upper_bound(A.data, B[vi], 0, n) + + +@T.prim_func +def global_add(a: T.handle) -> None: + A = T.match_buffer(a, (1,), dtype='int32') + for i in T.serial(0, 1024): + with T.block('global_add'): + T.block_attr({ + "atomic": True + }) + T.reads([A[0:1]]) + T.writes([A[0:1]]) + vi = T.axis.S(1024, i) + T.evaluate(T.atomic_add(A.data, vi)) + + +def test_binary_search(): + sch = tir.Schedule(binary_search) + b = sch.get_block('search') + i, = sch.get_loops(b) + io, ii = sch.split(i, [1, None]) + sch.bind(io, 'threadIdx.x') + sch.bind(ii, 'blockIdx.x') + f = tvm.build(sch.mod['main'], target='cuda') + # print(f.imported_modules[0].get_source()) + + x = np.arange(-128, 128).astype(np.int32) + y = np.random.randint(-200, 200, size=1024).astype(np.int32) + a = np.zeros((1024,)).astype(np.int32) + b = np.zeros((1024,)).astype(np.int32) + + # numpy results + np_a = np.searchsorted(x, y, side='left').astype(np.int32) + np_b = np.searchsorted(x, y, side='right').astype(np.int32) + + # tvm results + dev = tvm.cuda(0) + x_array = tvm.nd.array(x, device=dev) + y_array = tvm.nd.array(y, device=dev) + a_array = tvm.nd.array(a, device=dev) + b_array = tvm.nd.array(b, device=dev) + f(x_array, y_array, a_array, b_array) + tvm_a = a_array.numpy() + tvm_b = b_array.numpy() + + # verify result + tvm.testing.assert_allclose(np_a, tvm_a) + tvm.testing.assert_allclose(np_b, tvm_b) + + +def test_global_add(): + sch = tir.Schedule(global_add) + b = sch.get_block('global_add') + i, = sch.get_loops(b) + sch.bind(i, 'blockIdx.x') + f = tvm.build(sch.mod['main'], target='cuda') + + # create input and run kernel + dev = tvm.cuda(0) + a = np.zeros((1,)).astype(np.int32) + a_gpu = tvm.nd.array(a, device=dev) + f(a_gpu) + + # check output + tvm.testing.assert_allclose(a_gpu.numpy(), np.array([1024 * 1023 / 2]).astype(np.int32)) + + if __name__ == "__main__": test_nearbyint() test_unary_intrin() test_round_intrinsics_on_int() test_binary_intrin() test_ldexp() - test_clz() + # test_clz() test_fma() + test_binary_search() + test_global_add()