Skip to content

Commit

Permalink
Complete indices lowering (#30)
Browse files Browse the repository at this point in the history
* upd

* upd

* upd

* done

* upd

* passed test

* upd
  • Loading branch information
yzh119 committed Jan 24, 2022
1 parent c23a156 commit 21f7058
Show file tree
Hide file tree
Showing 10 changed files with 415 additions and 383 deletions.
62 changes: 38 additions & 24 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@
namespace tvm {
namespace tir {

/*!
 * \brief Kind tag of an axis in a sparse format: dense vs sparse storage,
 *        fixed vs variable extent. Each AxisNode subclass reports exactly one kind.
 */
enum class AxisKind : int {
  kDenseFixed = 0,      // dense axis, fixed length (DenseFixedAxisNode)
  kDenseVariable = 1,   // dense axis, variable length via indptr (DenseVariableAxisNode)
  kSparseFixed = 2,     // sparse axis, fixed number of stored columns (SparseFixedAxisNode)
  kSparseVariable = 3   // sparse axis, variable number of stored columns (SparseVariableAxisNode)
};

/*!
* \brief Base type for axis in sparse formats.
*/
Expand All @@ -49,6 +56,8 @@ class AxisNode : public Object {
/*! \brief Data type used for indices along this axis (derived from `length`). */
DataType GetIndexType() const { return length->dtype; }

/*! \brief Whether the number of elements along this axis is fixed. */
virtual bool is_fixed() const = 0;

/*! \brief The AxisKind tag (dense/sparse x fixed/variable) of this axis. */
virtual AxisKind kind() const = 0;

static constexpr const char* _type_key = "tir.sparse.Axis";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand Down Expand Up @@ -143,10 +152,14 @@ class DenseFixedAxisNode : public DenseAxisNode {
hash_reduce(from_sparse);
}

bool is_fixed() const {
bool is_fixed() const final{
return true;
}

AxisKind kind() const final {
return AxisKind::kDenseFixed;
}

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
};
Expand Down Expand Up @@ -183,10 +196,14 @@ class DenseVariableAxisNode : public DenseAxisNode {
hash_reduce(indptr);
}

bool is_fixed() const {
bool is_fixed() const final {
return false;
}

AxisKind kind() const final {
return AxisKind::kDenseVariable;
}

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};
Expand Down Expand Up @@ -230,10 +247,14 @@ class SparseFixedAxisNode : public SparseAxisNode {
hash_reduce(nnz_cols);
}

bool is_fixed() const {
bool is_fixed() const final {
return true;
}

AxisKind kind() const final {
return AxisKind::kSparseFixed;
}

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
};
Expand Down Expand Up @@ -276,10 +297,14 @@ class SparseVariableAxisNode : public SparseAxisNode {
hash_reduce(indices);
}

bool is_fixed() const {
bool is_fixed() const final {
return false;
}

AxisKind kind() const final {
return AxisKind::kSparseVariable;
}

static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
};
Expand Down Expand Up @@ -383,15 +408,9 @@ class SparseBuffer : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
};

/*!
 * \brief Kind of a sparse iterator variable; the enumerators mirror AxisKind
 *        (dense/sparse x fixed/variable) with identical integer values.
 */
enum class SpIterKind : int {
  kDenseFixed = 0,      // iterates a dense axis of fixed length
  kDenseVariable = 1,   // iterates a dense axis of variable length
  kSparseFixed = 2,     // iterates a sparse axis with a fixed number of stored columns
  kSparseVariable = 3   // iterates a sparse axis with a variable number of stored columns
};

// Overload the output-stream operator for the axis kind type.
TVM_DLL std::ostream& operator<<(std::ostream& os, SpIterKind kind);
TVM_DLL std::ostream& operator<<(std::ostream& os, AxisKind kind);

/*!
* \brief Iterator variables in SparseTIR
Expand All @@ -400,7 +419,6 @@ class SpIterVarNode : public Object {
public:
Var var;
PrimExpr max_extent;
SpIterKind kind;
bool is_reduction;
Axis axis;

Expand All @@ -409,21 +427,18 @@ class SpIterVarNode : public Object {
v->Visit("max_extent", &max_extent);
v->Visit("axis", &axis);
v->Visit("is_reduction", &is_reduction);
v->Visit("kind", &kind);
}

/*! \brief Structural equality: compares var, max_extent, axis and is_reduction. */
bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
  return equal(var, other->var) && equal(max_extent, other->max_extent) &&
         equal(axis, other->axis) && equal(is_reduction, other->is_reduction);
}

/*! \brief Structural hash over the fields compared by SEqualReduce. */
void SHashReduce(SHashReducer hash_reduce) const {
  hash_reduce(var);
  hash_reduce(max_extent);
  hash_reduce(axis);
  hash_reduce(is_reduction);
}

static constexpr const char* _type_key = "tir.sparse.SpIterVar";
Expand All @@ -434,8 +449,7 @@ class SpIterVarNode : public Object {

class SpIterVar : public ObjectRef {
public:
TVM_DLL explicit SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Axis axis);
TVM_DLL explicit SpIterVar(Var var, PrimExpr max_extent, bool is_reduction, Axis axis);

/*!
* \return the corresponding var in the IterVar.
Expand All @@ -449,18 +463,18 @@ class SpIterVar : public ObjectRef {
// Implicit conversion: using a SpIterVar where a PrimExpr is expected yields its underlying var.
inline SpIterVar::operator PrimExpr() const {
  return (*this)->var;
}

// inline implementations
inline const char* SpIterKind2String(SpIterKind t) {
inline const char* SpIterKind2String(AxisKind t) {
switch (t) {
case SpIterKind::kDenseFixed:
case AxisKind::kDenseFixed:
return "dense_fixed";
case SpIterKind::kDenseVariable:
case AxisKind::kDenseVariable:
return "dense_variable";
case SpIterKind::kSparseFixed:
case AxisKind::kSparseFixed:
return "sparse_fixed";
case SpIterKind::kSparseVariable:
case AxisKind::kSparseVariable:
return "sparse_variable";
}
LOG(FATAL) << "Unknown SpIterKind" << t;
LOG(FATAL) << "Unknown AxisKind" << t;
throw;
}

Expand Down
26 changes: 1 addition & 25 deletions python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,32 +261,8 @@ def comm_reducer(lambda_io, identities, span):


@register
def dense(axis: Axis, span: Optional[Span] = None):
    """Return a dense view of ``axis``.

    Sparse axes are wrapped in a ``DenseFixedAxis`` of the same length
    (named ``<name>_dense``); already-dense axes are returned unchanged.
    """
    if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)):
        return DenseFixedAxis(axis.name + "_dense", axis.length, axis)
    else:
        return axis


@register
def cord(axis: Axis, span: Optional[Span] = None):
    """Create a coordinate-space sparse iterator over ``axis``.

    The iterator's ``var`` and ``is_reduction`` fields are placeholders,
    patched later by the SparseBlock scope handler.
    """
    placeholder = tvm.te.var()
    kind = (
        SpIterVar.DenseVariable
        if isinstance(axis, DenseVariableAxis)
        else SpIterVar.DenseFixed
    )
    return SpIterVar(placeholder, axis.length, kind, False, axis)


@register
def pos(axis: Axis, span: Optional[Span] = None):
    """Create a position-space sparse iterator over ``axis``.

    The iterator's ``var`` and ``is_reduction`` fields are placeholders,
    patched later by the SparseBlock scope handler. Extent is ``nnz_cols``
    for sparse-fixed axes and ``length`` otherwise.
    """
    placeholder = tvm.te.var()
    if isinstance(axis, DenseFixedAxis):
        return SpIterVar(placeholder, axis.length, SpIterVar.DenseFixed, False, axis)
    if isinstance(axis, DenseVariableAxis):
        return SpIterVar(placeholder, axis.length, SpIterVar.DenseVariable, False, axis)
    if isinstance(axis, SparseFixedAxis):
        return SpIterVar(placeholder, axis.nnz_cols, SpIterVar.SparseFixed, False, axis)
    return SpIterVar(placeholder, axis.length, SpIterVar.SparseVariable, False, axis)
19 changes: 9 additions & 10 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,38 +329,37 @@ class SparseBlock(WithScopeHandler):
"""With scope handler of SparseBlock"""

def __init__(self):
def iter(iters: List, iter_types: str, name: str = "", span: Optional[Span] = None):

def iter(axes: List, iter_types: str, name: str = "", span: Optional[Span] = None):
assert (
self.node and self.context and self.body
), "call 'exit_scope' before 'enter_scope'"
block_info = self.context.block_info_stack[-1]

if len(iters) != len(self.sp_iters):
if len(axes) != len(self.sp_iters):
self.context.report_error(
"Inconsistent number of sparse iteration variable names, "
+ f"there are {len(iters)} iterators but {len(self.sp_iters)} names. "
+ f"there are {len(axes)} iterators but {len(self.sp_iters)} names. "
+ "The number of sparse iteration variable names should match the number of iterators.",
self.node.span,
)
if len(iters) != len(iter_types):
if len(axes) != len(iter_types):
self.context.report_error(
"Inconsistent number of sparse iteration variable types, "
+ f"there are {len(iters)} iterators but {len(iter_types)} types. "
+ f"there are {len(axes)} iterators but {len(iter_types)} types. "
+ "The number of sparse iteration variable types should match the number of iterators.",
self.node.span,
)

sp_iters: List[SpIterVar] = []
for i, sp_iter in enumerate(iters):
assert isinstance(sp_iter, SpIterVar)
for i, axis in enumerate(axes):
is_reduction = True if iter_types[i] == "R" else False
sp_iters.append(
SpIterVar(
self.sp_iters[i],
sp_iter.max_extent,
sp_iter.kind,
axis.length,
is_reduction,
sp_iter.axis,
axis,
)
)

Expand Down
13 changes: 5 additions & 8 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
# under the License.
"""SparseTIR axes and SparseBuffer
"""
from typing import List, Dict, Optional
from typing import Dict, List, Optional

import tvm._ffi
from tvm.ir import PrimExpr
from tvm.runtime import Object, const
from tvm.runtime import Object
from tvm.tir import Var

from . import _ffi_api
Expand Down Expand Up @@ -219,9 +220,6 @@ class SpIterVar(Object):
max_extent : PrimExpr
The maximum extent of the SpIterVar
kind : int
The kind of the SpIterVar
is_reduction : bool
Whether the SpIterVar is a reduction iterator
Expand All @@ -231,7 +229,6 @@ class SpIterVar(Object):

var: Var
max_extent: PrimExpr
kind: int
is_reduction: bool
axis: Axis

Expand All @@ -240,7 +237,7 @@ class SpIterVar(Object):
SparseFixed = 2
SparseVariable = 3

def __init__(self, var, max_extent, is_reduction, axis):
    """Construct a SpIterVar handle via the C++ FFI constructor.

    Parameters mirror the attributes documented on the class:
    var, max_extent, is_reduction and axis.
    """
    self.__init_handle_by_constructor__(
        _ffi_api.SpIterVar, var, max_extent, is_reduction, axis  # type: ignore
    )
9 changes: 3 additions & 6 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ Doc TVMScriptPrinter::AllocAxis(const Axis& axis) {
const auto* df_axis = axis.as<DenseFixedAxisNode>();

if (df_axis != nullptr && df_axis->from_sparse.defined()) {
val << tir_prefix_ << ".to_dense(" << Print(df_axis->from_sparse.value()) << ")";
val << tir_prefix_ << ".dense(" << Print(df_axis->from_sparse.value()) << ")";
} else {
std::string name = axis->name;
if (name.length() == 0 || !std::isalnum(name[0])) {
Expand Down Expand Up @@ -1329,11 +1329,8 @@ Doc TVMScriptPrinter::PrintSparseBlockName(const SparseBlockNode* op) {
for (int i = 0; i < n_iter; ++i) {
const SpIterVar& sp_iter = op->sp_iter_vars[i];
Doc iter_doc;
if (sp_iter->kind == SpIterKind::kDenseFixed || sp_iter->kind == SpIterKind::kDenseVariable) {
iter_doc << tir_prefix_ << ".cord(" << sp_iter->axis->name << ")";
} else {
iter_doc << tir_prefix_ << ".pos(" << sp_iter->axis->name << ")";
}
iter_doc << sp_iter->axis->name;
// TODO(zihao): fix expressions like T.dense(J)
var_not_in_headers_.insert(sp_iter->var.get());
sp_iter_docs.push_back(iter_doc);
sp_iter_name_docs.push_back(Print(sp_iter->var));
Expand Down
31 changes: 10 additions & 21 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,19 +211,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "], " << op->data << ")";
});

// SpIterKind
std::ostream& operator<<(std::ostream& out, SpIterKind type) {
// AxisKind
std::ostream& operator<<(std::ostream& out, AxisKind type) {
switch (type) {
case SpIterKind::kDenseFixed:
case AxisKind::kDenseFixed:
out << "dense-fixed";
break;
case SpIterKind::kDenseVariable:
case AxisKind::kDenseVariable:
out << "dense-variable";
break;
case SpIterKind::kSparseFixed:
case AxisKind::kSparseFixed:
out << "sparse-fixed";
break;
case SpIterKind::kSparseVariable:
case AxisKind::kSparseVariable:
out << "sparse-variable";
break;
default:
Expand All @@ -233,24 +233,13 @@ std::ostream& operator<<(std::ostream& out, SpIterKind type) {
}

// SpIterVar
SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction, Axis axis) {
SpIterVar::SpIterVar(Var var, PrimExpr max_extent, bool is_reduction, Axis axis) {
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();

arith::Analyzer ana;
const char* err_str = "ValueError: The given kind doesn't match the type of the given axis";
if (kind == SpIterKind::kDenseFixed) {
CHECK(!axis->IsInstance<DenseVariableAxisNode>()) << err_str;
} else if (kind == SpIterKind::kDenseVariable) {
CHECK(axis->IsInstance<DenseVariableAxisNode>()) << err_str;
} else if (kind == SpIterKind::kSparseFixed) {
CHECK(axis->IsInstance<SparseFixedAxisNode>()) << err_str;
} else if (kind == SpIterKind::kSparseVariable) {
CHECK(axis->IsInstance<SparseVariableAxisNode>()) << err_str;
}

node->var = Var(std::move(var));
node->max_extent = std::move(max_extent);
node->kind = kind;
node->is_reduction = is_reduction;
node->axis = std::move(axis);
data_ = std::move(node);
Expand All @@ -259,15 +248,15 @@ SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_redu
TVM_REGISTER_NODE_TYPE(SpIterVarNode);

TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
.set_body_typed([](Var var, PrimExpr max_extent, int kind, bool is_reduction, Axis axis) {
return SpIterVar(var, max_extent, SpIterKind(kind), is_reduction, axis);
.set_body_typed([](Var var, PrimExpr max_extent, bool is_reduction, Axis axis) {
return SpIterVar(var, max_extent, is_reduction, axis);
});

// Debug printer: emits sp_iter_var(<var>, <max_extent>, reduction|spatial, <axis name>).
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<SpIterVarNode>([](const ObjectRef& node, ReprPrinter* p) {
      auto* op = static_cast<const SpIterVarNode*>(node.get());
      p->stream << "sp_iter_var(" << op->var->name_hint << ", " << op->max_extent << ", "
                << (op->is_reduction ? "reduction" : "spatial") << ", "
                << op->axis->name << ")";
    });

Expand Down
Loading

0 comments on commit 21f7058

Please sign in to comment.