Skip to content

Commit

Permalink
Fix AxisTree (#3)
Browse files Browse the repository at this point in the history
* fix axis tree

* upd
  • Loading branch information
yzh119 authored and MasterJH5574 committed Jan 26, 2022
1 parent b49b752 commit 8c0c440
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 30 deletions.
29 changes: 19 additions & 10 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,15 +268,23 @@ class SparseVariableAxis : public SparseAxis {
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
};


/*!
* \brief Axis Dependency Tree.
*/
class AxisTreeNode : public Object {
public:
// parent refers to the parent axis of current axis tree.
Optional<AxisTree> parent;
Axis axis;
Array<AxisTree> children;
// mapping from names to axes.
std::unordered_map<String, Axis> axis_map;
// unordered map that stores the parent relationship between axes.
std::unordered_map<Axis, Axis, ObjectPtrHash, ObjectPtrEqual> parent;
// unordered map that stores the children relationship between axes.
std::unordered_map<Axis, Array<Axis>, ObjectPtrHash, ObjectPtrEqual> children;
// The root axis.
Axis root;

void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "tir.sparse.AxisTree";
TVM_DECLARE_FINAL_OBJECT_INFO(AxisTreeNode, Object);
};
Expand All @@ -287,6 +295,7 @@ class AxisTreeNode : public Object {
*/
class AxisTree : public ObjectRef {
public:
TVM_DLL AxisTree(Array<Axis> axes, Array<Optional<String>> axis_parent_names);
TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode);
};

Expand All @@ -296,7 +305,7 @@ class AxisTree : public ObjectRef {
class SparseBufferNode : public Object {
public:
/* Root of Axis Dependency Tree. */
AxisTree root;
AxisTree tree;
/* Axes */
Array<Axis> axes;
/* Number of dimensions */
Expand All @@ -305,25 +314,25 @@ class SparseBufferNode : public Object {
Buffer data;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &root);
v->Visit("name", &tree);
v->Visit("length", &axes);
v->Visit("indptr", &ndim);
v->Visit("num_cols", &data);
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
return equal(root, other->root) && equal(axes, other->axes) && equal(ndim, other->ndim) &&
return equal(tree, other->tree) && equal(axes, other->axes) && equal(ndim, other->ndim) &&
equal(data, other->data);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(root);
hash_reduce(tree);
hash_reduce(axes);
hash_reduce(ndim);
hash_reduce(data);
}

static constexpr const char* _type_key = "tir.sparse.SparseBufferNode";
static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBufferNode, Object);
};

Expand All @@ -333,7 +342,7 @@ class SparseBufferNode : public Object {
*/
class SparseBuffer : public ObjectRef {
public:
TVM_DLL explicit SparseBuffer(AxisTree root, Array<Axis> axes, int ndim, Buffer data);
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, int ndim, Buffer data);

TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
};
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@
from . import analysis
from . import stmt_functor
from . import usmp
from . import sparse
1 change: 1 addition & 0 deletions python/tvm/tir/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@


tvm._ffi._init_api("tir", __name__)
tvm._ffi._init_api("tir.sparse", __name__)
30 changes: 22 additions & 8 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""SparseTIR axes and SparseBuffer
"""
from typing import List
from typing import List, Dict, Optional
import tvm._ffi
from tvm.ir import PrimExpr
from tvm.runtime import Object, const
Expand Down Expand Up @@ -146,9 +146,23 @@ def __init__(self, name, length, indptr, indices):


@tvm._ffi.register_object("tir.sparse.AxisTree")
class AxisTree:
# Todo(@ruihang): to do later
pass
class AxisTree(Object):
"""AxisTree node
Parameters
----------
axis_parent_map: Dict
A dictionary that maps Axis to parent axis name, value is None if there is not parent axis.
"""

axis_parent_map: Dict[Axis, Optional[str]]

def __init__(self, axis_parent_map) -> None:
keys = list(axis_parent_map.keys())
values = list(axis_parent_map.values())
self.__init_handle_by_constructor__(
_ffi_api.AxisTree, keys, values # type:ignore
)


@tvm._ffi.register_object("tir.sparse.SparseBuffer")
Expand All @@ -157,8 +171,8 @@ class SparseBuffer:
Parameters
----------
root : AxisTree
The root of the axis dependency tree of the sparse buffer
tree : AxisTree
The axis dependency tree of the sparse buffer
axes : List[Axis]
The axes of the sparse buffer
Expand All @@ -170,12 +184,12 @@ class SparseBuffer:
The data of the sparse buffer
"""

root: AxisTree
tree: AxisTree
axes: List[Axis]
ndim: int
data: Buffer

def __init__(self, root, axes, ndim, data):
def __init__(self, tree, axes, ndim, data):
self.__init_handle_by_constructor__(
_ffi_api.SparseBuffer, root, axes, ndim, data # type: ignore
)
77 changes: 65 additions & 12 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ namespace tir {

namespace sparse {


// DenseFixedAxis
DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {
ObjectPtr<DenseFixedAxisNode> node = make_object<DenseFixedAxisNode>();
Expand All @@ -40,12 +41,14 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) {

TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) {
return DenseFixedAxis(name, length);
});
TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis")
.set_body_typed([](String name, PrimExpr length) {
return DenseFixedAxis(name, length);
});

// DenseVariableAxis
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) {
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length,
Buffer indptr) {
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
Expand All @@ -61,7 +64,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
});

// SparseFixedAxis
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices,
PrimExpr num_cols) {
ObjectPtr<SparseFixedAxisNode> node = make_object<SparseFixedAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
Expand All @@ -73,14 +77,16 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, P
TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) {
.set_body_typed([](String name, PrimExpr length, Buffer indices,
PrimExpr num_cols) {
return SparseFixedAxis(name, length, indices, num_cols);
});

// SparseVariableAxis
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr,
Buffer indices) {
ObjectPtr<SparseVariableAxisNode> node = make_object<SparseVariableAxisNode>();
SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length,
Buffer indptr, Buffer indices) {
ObjectPtr<SparseVariableAxisNode> node =
make_object<SparseVariableAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
node->indptr = std::move(indptr);
Expand All @@ -91,14 +97,61 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indp
TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
.set_body_typed([](String name, PrimExpr length, Buffer indptr,
Buffer indices) {
return SparseVariableAxis(name, length, indptr, indices);
});

// AxisTree
AxisTree::AxisTree(Array<Axis> axes,
Array<Optional<String>> axis_parent_names) {
CHECK_EQ(axes.size(), axis_parent_names.size())
<< "ValueError: The axes array should have the same length as axis_parent_names "
"array.";
ObjectPtr<AxisTreeNode> node = make_object<AxisTreeNode>();
Axis root = Downcast<Axis>(RootAxis());
for (const Axis& axis : axes) {
// update axis map
String name = axis->name;
CHECK(node->axis_map.find(name) != node->axis_map.end()) << "ValueError: duplicate axis names.";
node->axis_map[name] = axis;
}
for (size_t i = 0; i < axes.size(); i++) {
// update parent map & children map
Axis axis = axes[i];
Optional<String> parent_name = axis_parent_names[i];
if (parent_name.get() != nullptr) {
CHECK(node->axis_map.find(parent_name.value()) != node->axis_map.end())
<< "ValueError: Parent axis name doesn't exist.";
}
Axis parent_axis = (parent_name.get() != nullptr)
? node->axis_map[parent_name.value()]
: root;
node->parent[axis] = parent_axis;
if (node->children.find(parent_axis) != node->children.end()) {
node->children[parent_axis].push_back(axis);
} else {
Array<Axis> children;
children.push_back(axis);
node->children[parent_axis] = std::move(children);
}
}
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(AxisTreeNode);

TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
.set_body_typed([](Array<Axis> axes,
Array<Optional<String>> axis_parent_names) {
return AxisTree(axes, axis_parent_names);
});

// SparseBuffer
SparseBuffer::SparseBuffer(AxisTree root, Array<Axis> axes, int ndim, Buffer data) {
SparseBuffer::SparseBuffer(AxisTree tree, Array<Axis> axes, int ndim,
Buffer data) {
ObjectPtr<SparseBufferNode> node = make_object<SparseBufferNode>();
node->root = std::move(root);
node->tree = std::move(tree);
node->axes = std::move(axes);
node->ndim = ndim;
node->data = std::move(data);
Expand Down
33 changes: 33 additions & 0 deletions tests/python/unittest/test_tir_sparse_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.tir as tir

def test_format_tree_creation():
i = tir.sparse.DenseFixedAxis('i', 128)
j = tir.sparse.DenseFixedAxis('j', 128)
k = tir.sparse.DenseFixedAxis('k', 128)
tree = tir.sparse.AxisTree({
i: None,
j: None,
k: None
})
print(tree)


if __name__ == "__main__":
test_format_tree_creation()

0 comments on commit 8c0c440

Please sign in to comment.