diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index e184dc050856..ac40fea615a1 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -355,6 +355,64 @@ class SparseBuffer : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode); }; +enum class SpIterKind : int { + kDenseFixed = 0, + kDenseVariable = 1, + kSparseFixed = 2, + kSparseVariable = 3 +}; + +/*! + * \brief Iterator variables in SparseTIR + */ +class SpIterVarNode : public Object { + public: + Var var; + PrimExpr max_extent; + SpIterKind kind; + Optional axis; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("var", &var); + v->Visit("max_extent", &max_extent); + v->Visit("axis", &axis); + v->Visit("kind", &kind); + } + + bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const { + return equal(var, other->var) && equal(max_extent, other->max_extent) && + equal(axis, other->axis) && equal(kind, other->kind); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(var); + hash_reduce(max_extent); + hash_reduce(axis); + hash_reduce(kind); + } + + static constexpr const char* _type_key = "tir.sparse.SpIterVar"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(SpIterVarNode, Object); +}; + +class SpIterVar : public ObjectRef { + public: + TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, + Optional axis = NullOpt); + + /*! + * \return the corresponding var in the IterVar. + */ + inline operator PrimExpr() const; + + TVM_DEFINE_OBJECT_REF_METHODS(SpIterVar, ObjectRef, SpIterVarNode); +}; + +// inline implementations +inline SpIterVar::operator PrimExpr() const { return (*this)->var; } + } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index 4ec289aa70ed..7f5c38585980 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -20,6 +20,7 @@ import tvm._ffi from tvm.ir import PrimExpr from tvm.runtime import Object, const +from tvm.tir import Var from . import _ffi_api from .buffer import Buffer @@ -166,7 +167,7 @@ def __init__(self, axis_parent_map) -> None: @tvm._ffi.register_object("tir.sparse.SparseBuffer") -class SparseBuffer: +class SparseBuffer(Object): """SparseBuffer node Parameters @@ -197,3 +198,39 @@ def __init__(self, tree, axes, data, name, dtype=None): self.__init_handle_by_constructor__( _ffi_api.SparseBuffer, tree, axes, data, name, dtype # type: ignore ) + + +@tvm._ffi.register_object("tir.sparse.SpIterVar") +class SpIterVar(Object): + """IterVar in SparseTIR + + Parameters + ---------- + var : Var + The var of the SpIterVar + + max_extent : PrimExpr + The maximum extent of the SpIterVar + + kind : int + The kind of the SpIterVar + + axis : Optional[Axis] + The axis over which the SpIterVar iterates. Required to be defined + when `kind` is not `DenseFixed` + """ + var: Var + max_extent: PrimExpr + kind: int + axis: Optional[Axis] + + DenseFixed = 0 + DenseVariable = 1 + SparseFixed = 2 + SparseVariable = 3 + + def __init__(self, var, max_extent, kind, axis=None): + self.__init_handle_by_constructor__( + _ffi_api.SpIterVar, var, max_extent, kind, axis # type: ignore + ) + diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index f8519865666c..17eca58bcf7a 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -163,5 +163,28 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer") return SparseBuffer(tree, axes, data, name, dtype); }); +// SpIterVar +SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, Optional axis) { + ObjectPtr node = make_object(); + + if (kind != SpIterKind::kDenseFixed) { + CHECK(axis.defined()) << "ValueError: To create a SpIterVar that is not fixed-dense, one must " + "specify the axis over which the SpIterVar iterates"; + } + + node->var = Var(std::move(name)); + node->max_extent = std::move(max_extent); + node->kind = kind; + node->axis = std::move(axis); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(SpIterVarNode); + +TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar") + .set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, Optional axis) { + return SpIterVar(name, max_extent, kind, axis); + }); + } // namespace tir } // namespace tvm