Skip to content

Commit

Permalink
[SparseTIR] Constructors and Python Interface for Axis and `SparseB…
Browse files Browse the repository at this point in the history
…uffer` (#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface
  • Loading branch information
MasterJH5574 authored and yzh119 committed Nov 16, 2021
1 parent 669a6c9 commit aaf686b
Show file tree
Hide file tree
Showing 3 changed files with 380 additions and 10 deletions.
120 changes: 111 additions & 9 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ class AxisNode : public Object {
/* length of current axis. For sparse axis, length refers to the upperbound of
* the current axis. */
PrimExpr length;

static constexpr const char* _type_key = "tir.sparse.Axis";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(AxisNode, Object);
};

Expand Down Expand Up @@ -98,6 +101,20 @@ class DenseAxis : public Axis {
*/
class DenseFixedAxisNode : public DenseAxisNode {
public:
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
}

bool SEqualReduce(const DenseAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
}

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
};
Expand All @@ -108,12 +125,31 @@ class DenseFixedAxisNode : public DenseAxisNode {
*/
class DenseFixedAxis : public DenseAxis {
public:
TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length);

TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode);
};

class DenseVariableAxisNode : public DenseAxisNode {
public:
Buffer indptr;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
v->Visit("indptr", &indptr);
}

bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
hash_reduce(indptr);
}

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};
Expand All @@ -124,8 +160,9 @@ class DenseVariableAxisNode : public DenseAxisNode {
*/
class DenseVariableAxis : public DenseAxis {
public:
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis,
DenseVariableAxisNode);
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr);

TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
};

/*!
Expand Down Expand Up @@ -154,6 +191,26 @@ class SparseFixedAxisNode : public SparseAxisNode {
Buffer indices;
/* fixed number of columns of current sparse axis. */
PrimExpr num_cols;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
v->Visit("indptr", &indices);
v->Visit("num_cols", &num_cols);
}

bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indices, other->indices) && equal(num_cols, other->num_cols);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
hash_reduce(indices);
hash_reduce(num_cols);
}

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
};
Expand All @@ -164,17 +221,39 @@ class SparseFixedAxisNode : public SparseAxisNode {
*/
class SparseFixedAxis : public SparseAxis {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis,
SparseFixedAxisNode);
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols);

TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
};

/*!
* \brief Sparse axis with variable number of non-zero columns per row.
*/
class SparseVariableAxisNode : public SparseAxisNode {
public:
Buffer indptr, indices;
static constexpr const char* _type_key = "tir.sparse.SparseVariabledAxis";
Buffer indptr;
Buffer indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
v->Visit("indptr", &indptr);
v->Visit("indices", &indices);
}

bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr) && equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
hash_reduce(indptr);
hash_reduce(indices);
}

static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
};

Expand All @@ -184,8 +263,9 @@ class SparseVariableAxisNode : public SparseAxisNode {
*/
class SparseVariableAxis : public SparseAxis {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis,
SparseVariableAxisNode);
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices);

TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
};

/*!
Expand Down Expand Up @@ -223,6 +303,26 @@ class SparseBufferNode : public Object {
int ndim;
/* Buffer corresponding to flattened value */
Buffer data;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &root);
v->Visit("length", &axes);
v->Visit("indptr", &ndim);
v->Visit("num_cols", &data);
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
return equal(root, other->root) && equal(axes, other->axes) && equal(ndim, other->ndim) &&
equal(data, other->data);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(root);
hash_reduce(axes);
hash_reduce(ndim);
hash_reduce(data);
}

static constexpr const char* _type_key = "tir.sparse.SparseBufferNode";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object);
};
Expand All @@ -233,11 +333,13 @@ class SparseBufferNode : public Object {
*/
class SparseBuffer : public ObjectRef {
public:
TVM_DLL explicit SparseBuffer(AxisTree root, Array<Axis> axes, int ndim, Buffer data);

TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
};

} // namespace sparse
} // namespace tir
} // namespace tvm

#endif // TVM_TIR_BUFFER_H_
#endif // TVM_TIR_SPARSE_H_
181 changes: 181 additions & 0 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""SparseTIR axes and SparseBuffer
"""
from typing import List
import tvm._ffi
from tvm.ir import PrimExpr
from tvm.runtime import Object, const

from . import _ffi_api
from .buffer import Buffer


class Axis(Object):
"""Base class of all the sparse axes."""


class DenseAxis(Axis):
pass


class SparseAxis(Axis):
pass


@tvm._ffi.register_object("tir.sparse.DenseFixedAxis")
class DenseFixedAxis(DenseAxis):
"""DenseFixedAxis node
Parameters
----------
name : str
The name of the axis
length : PrimExpr
The length of the axis
"""

name: str
length: PrimExpr

def __init__(self, name, length):
self.__init_handle_by_constructor__(
_ffi_api.DenseFixedAxis, name, length # type: ignore
)


@tvm._ffi.register_object("tir.sparse.DenseVariableAxis")
class DenseVariableAxis(DenseAxis):
"""DenseVariableAxis node
Parameters
----------
name : str
The name of the axis
length : PrimExpr
The length of the axis
indptr : Buffer
The indptr buffer of the axis
"""

name: str
length: PrimExpr
indptr: Buffer

def __init__(self, name, length, indptr):
self.__init_handle_by_constructor__(
_ffi_api.DenseVariableAxis, name, length, indptr # type: ignore
)


@tvm._ffi.register_object("tir.sparse.SparseFixedAxis")
class SparseFixedAxis(DenseAxis):
"""SparseFixedAxis node
Parameters
----------
name : str
The name of the axis
length : PrimExpr
The length of the axis
indices : Buffer
The indices buffer of the axis
num_cols : PrimExpr
The number of non-zero elements along the axis
"""

name: str
length: PrimExpr
indices: Buffer
num_cols: PrimExpr

def __init__(self, name, length, indices, num_cols):
self.__init_handle_by_constructor__(
_ffi_api.SparseFixedAxis, name, length, indices, num_cols # type: ignore
)


@tvm._ffi.register_object("tir.sparse.SparseVariableAxis")
class SparseVariableAxis(DenseAxis):
"""SparseVariableAxis node
Parameters
----------
name : str
The name of the axis
length : PrimExpr
The length of the axis
indptr : Buffer
The indptr buffer of the axis
indices : Buffer
The indices buffer of the axis
"""

name: str
length: PrimExpr
indptr: Buffer
indices: Buffer

def __init__(self, name, length, indptr, indices):
self.__init_handle_by_constructor__(
_ffi_api.SparseVariableAxis, name, length, indptr, indices # type: ignore
)


@tvm._ffi.register_object("tir.sparse.AxisTree")
class AxisTree:
# Todo(@ruihang): to do later
pass


@tvm._ffi.register_object("tir.sparse.SparseBuffer")
class SparseBuffer:
"""SparseBuffer node
Parameters
----------
root : AxisTree
The root of the axis dependency tree of the sparse buffer
axes : List[Axis]
The axes of the sparse buffer
ndim : int
The number of dimensions of the sparse buffer
data : Buffer
The data of the sparse buffer
"""

root: AxisTree
axes: List[Axis]
ndim: int
data: Buffer

def __init__(self, root, axes, ndim, data):
self.__init_handle_by_constructor__(
_ffi_api.SparseBuffer, root, axes, ndim, data # type: ignore
)
Loading

0 comments on commit aaf686b

Please sign in to comment.