Skip to content

Commit

Permalink
Fatal bugfix and change the signature of DenseVariableAxis. (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Dec 14, 2021
1 parent 77562c0 commit 37bdbfb
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 28 deletions.
16 changes: 11 additions & 5 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class AxisNode : public Object {
DataType GetIndexType() const { return length->dtype; }

virtual AxisKind kind() const = 0;
virtual PrimExpr nnz() const = 0;

static constexpr const char* _type_key = "tir.sparse.Axis";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand Down Expand Up @@ -134,6 +135,8 @@ class DenseFixedAxisNode : public DenseAxisNode {
public:
AxisKind kind() const final { return AxisKind::kDenseFixed; }

PrimExpr nnz() const final { return length; }

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
};
Expand Down Expand Up @@ -234,6 +237,7 @@ class FusedAxis : public DenseFixedAxis {
class DenseVariableAxisNode : public DenseAxisNode {
public:
Buffer indptr;
PrimExpr nnz_;

void VisitAttrs(AttrVisitor* v) {
DenseAxisNode::VisitAttrs(v);
Expand All @@ -249,10 +253,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
hash_reduce(indptr);
}

PrimExpr nnz() const { return indptr->shape[0]; }

AxisKind kind() const final { return AxisKind::kDenseVariable; }

PrimExpr nnz() const final { return nnz_; }

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};
Expand All @@ -263,7 +267,7 @@ class DenseVariableAxisNode : public DenseAxisNode {
*/
class DenseVariableAxis : public DenseAxis {
public:
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr);
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr);

TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
};
Expand All @@ -289,11 +293,13 @@ class SparseFixedAxisNode : public SparseAxisNode {
}

void SHashReduce(SHashReducer hash_reduce) const {
SparseFixedAxisNode::SHashReduce(hash_reduce);
SparseAxisNode::SHashReduce(hash_reduce);
hash_reduce(indices);
hash_reduce(nnz_cols);
}

PrimExpr nnz() const { return indices->shape[0]; }

AxisKind kind() const final { return AxisKind::kSparseFixed; }

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
Expand Down Expand Up @@ -336,7 +342,7 @@ class SparseVariableAxisNode : public SparseAxisNode {
hash_reduce(indices);
}

PrimExpr nnz() const { return indptr->shape[0]; }
PrimExpr nnz() const { return indices->shape[0]; }

AxisKind kind() const final { return AxisKind::kSparseVariable; }

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,11 +885,11 @@ def dense_variable(
f"`dense_variable` expected assign to only one var, but got {names}", span
)

length, indptr_len = shape
length, indptr_len, nnz = shape
indptr_buf = tvm.tir.decl_buffer(
(indptr_len,), dtype=idtype, name=names[0] + "_indptr", span=span
)
axis = DenseVariableAxis(names[0], length, indptr_buf)
axis = DenseVariableAxis(names[0], length, nnz, indptr_buf)
self.context.sp_struct.append(axis)
self.context.sp_struct_params.append([indptr_var])
self.context.update_symbol(names[0], axis, self.node)
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,12 @@ class DenseVariableAxis(DenseAxis):

name: str
length: PrimExpr
nnz: PrimExpr
indptr: Buffer

def __init__(self, name, length, indptr):
def __init__(self, name, length, nnz, indptr):
self.__init_handle_by_constructor__(
_ffi_api.DenseVariableAxis, name, length, indptr # type: ignore
_ffi_api.DenseVariableAxis, name, length, nnz, indptr # type: ignore
)


Expand Down
29 changes: 10 additions & 19 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
/******** DenseVariableAxis ********/

/*! \brief Default constuctor of DenseVariableAxis */
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) {
DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, PrimExpr nnz, Buffer indptr) {
ObjectPtr<DenseVariableAxisNode> node = make_object<DenseVariableAxisNode>();
node->name = std::move(name);
node->length = std::move(length);
node->nnz_ = std::move(nnz);
node->indptr = std::move(indptr);
data_ = std::move(node);
}

TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indptr) {
return DenseVariableAxis(name, length, indptr);
.set_body_typed([](String name, PrimExpr length, PrimExpr nnz, Buffer indptr) {
return DenseVariableAxis(std::move(name), std::move(length), std::move(nnz), std::move(indptr));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down Expand Up @@ -128,17 +129,7 @@ FusedAxis::FusedAxis(Array<Axis> group, int index) {
fused_name += group[i]->name;
}
node->name = "fused_" + fused_name + "_" + group[index]->name;

if (const auto* df_axis = group[index].as<DenseFixedAxisNode>()) {
node->length = df_axis->length;
} else if (const auto* sf_axis = group[index].as<SparseFixedAxisNode>()) {
// TODO(zihao): accumulate previous dimensions.
} else if (const auto* dv_axis = group[index].as<DenseVariableAxisNode>()) {
node->length = dv_axis->nnz();
} else if (const auto* sv_axis = group[index].as<SparseVariableAxisNode>()) {
node->length = sv_axis->nnz();
}

node->length = group[index]->nnz();
node->is_derived_axis = true;
node->group = std::move(group);
node->index = index;
Expand Down Expand Up @@ -183,7 +174,7 @@ TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols) {
return SparseFixedAxis(name, length, indices, nnz_cols);
return SparseFixedAxis(std::move(name), std::move(length), std::move(indices), std::move(nnz_cols));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand All @@ -210,7 +201,7 @@ TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis")
.set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) {
return SparseVariableAxis(name, length, indptr, indices);
return SparseVariableAxis(std::move(name), std::move(length), std::move(indptr), std::move(indices));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down Expand Up @@ -259,7 +250,7 @@ TVM_REGISTER_NODE_TYPE(AxisTreeNode);

TVM_REGISTER_GLOBAL("tir.sparse.AxisTree")
.set_body_typed([](Array<String> axis_names, Array<Optional<String>> axis_parent_names) {
return AxisTree(axis_names, axis_parent_names);
return AxisTree(std::move(axis_names), std::move(axis_parent_names));
});

/******** SparseBuffer ********/
Expand All @@ -279,7 +270,7 @@ TVM_REGISTER_NODE_TYPE(SparseBufferNode);

TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer")
.set_body_typed([](Array<Axis> axes, Buffer data, String name) {
return SparseBuffer(axes, data, name);
return SparseBuffer(std::move(axes), std::move(data), std::move(name));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down Expand Up @@ -338,7 +329,7 @@ TVM_REGISTER_NODE_TYPE(SpIterVarNode);

TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
.set_body_typed([](Var var, PrimExpr max_extent, bool is_reduction, Axis axis) {
return SpIterVar(var, max_extent, is_reduction, axis);
return SpIterVar(std::move(var), std::move(max_extent), is_reduction, std::move(axis));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down

0 comments on commit 37bdbfb

Please sign in to comment.