Skip to content

Commit

Permalink
Finish index lowering (maybe)
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Nov 1, 2021
1 parent d648e99 commit deb46df
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 82 deletions.
15 changes: 14 additions & 1 deletion src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file sparse.cc
* \brief buffers and formats in sparse tir.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/sparse.h>
Expand Down Expand Up @@ -158,15 +159,27 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_
Optional<Axis> axis) {
ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();

arith::Analyzer ana;
if (axis.defined()) {
CHECK(ana.CanProveEqual(axis.value()->length, max_extent));
}
if (kind != SpIterKind::kDenseFixed) {
CHECK(axis.defined()) << "ValueError: To create a SpIterVar that is not fixed-dense, one must "
"specify the axis over which the SpIterVar iterates";
const char* err_str = "ValueError: The given kind doesn't match the type of the given axis";
if (kind == SpIterKind::kDenseVariable) {
CHECK(axis.value()->IsInstance<DenseFixedAxisNode>()) << err_str;
} else if (kind == SpIterKind::kSparseFixed) {
CHECK(axis.value()->IsInstance<SparseFixedAxisNode>()) << err_str;
} else if (kind == SpIterKind::kSparseVariable) {
CHECK(axis.value()->IsInstance<SparseVariableAxisNode>()) << err_str;
}
}

node->var = Var(std::move(name));
node->max_extent = std::move(max_extent);
node->kind = kind;
node->is_reduction = is_reduction;
node->is_reduction = is_reduction;
node->axis = std::move(axis);
data_ = std::move(node);
}
Expand Down
286 changes: 205 additions & 81 deletions src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,132 +33,256 @@
namespace tvm {
namespace tir {

class SparseTIRLowerer : public StmtExprMutator {
/*!
* \brief Check whether a given SparseBuffer contains the given axis.
* \brief buffer The SparseBuffer to be checked
* \brief axis The axis to be checked
* \return A boolean indicating whether the given SparseBuffer contains the given axis
*/
bool BufferContainsAxis(const SparseBuffer& buffer, const Axis& axis) {
for (int i = 0; i < static_cast<int>(buffer->axes.size()); ++i) {
if (buffer->axes[i].same_as(axis)) {
return true;
}
}
return false;
}

using BufferAccessMap = Map<SparseBuffer, Array<SpIterVar>>;
using DependencyMap =
std::unordered_map<SpIterVar, std::pair<SparseBuffer, int>, ObjectPtrHash, ObjectPtrEqual>;

/*
* \brief For each sparse-fixed or sparse-variable iterator, collect the iterators that it depends
* on.
*/
class AccessAndDependencyCollector : public StmtExprVisitor {
public:
void Collect(Stmt stmt) {
VisitStmt(std::move(stmt));

for (const std::pair<SparseBuffer, Array<SpIterVar>>& kv_pair : buffer_access_map_) {
const SparseBuffer& buffer = kv_pair.first;
int ndim = static_cast<int>(kv_pair.second.size());
for (int k = 0; k < ndim; ++k) {
const SpIterVar& sp_iter = kv_pair.second[k];
if (sp_iter->kind == SpIterKind::kDenseFixed ||
sp_iter->kind == SpIterKind::kDenseVariable ||
!BufferContainsAxis(buffer, sp_iter->axis.value())) {
continue;
}

ICHECK(dependency_map_.count(sp_iter) == 0);
dependency_map_[sp_iter] = std::make_pair(buffer, k);
}
}
}

BufferAccessMap buffer_access_map_;
DependencyMap dependency_map_;

private:
void AddAccessPattern(const SparseBuffer& buffer, const Array<PrimExpr>& indices) {
int ndim = buffer->ndim();
CHECK_EQ(static_cast<int>(indices.size()), ndim);

Array<SpIterVar> iters;
iters.reserve(ndim);
for (int i = 0; i < ndim; ++i) {
const SpIterVarNode* sp_iter = indices[i].as<SpIterVarNode>();
CHECK(sp_iter) << "ValueError: Currently an index is only allowed to be SpIterVar";
iters.push_back(GetRef<SpIterVar>(sp_iter));
}

BufferAccessMap::iterator it = buffer_access_map_.find(buffer);
if (it == buffer_access_map_.end()) {
buffer_access_map_.Set(buffer, iters);
} else {
ICHECK_EQ(static_cast<int>((*it).second.size()), ndim);
for (int i = 0; i < ndim; ++i) {
CHECK((*it).second[i].same_as(iters[i]))
<< "ValueError: Currently all accesses to a same buffer are required to be the same";
}
}
}

void VisitStmt_(const SparseBufferStoreNode* store) final {
ExprVisitor::VisitExpr(store->value);
AddAccessPattern(store->buffer, store->indices);
}

void VisitExpr_(const SparseBufferLoadNode* load) final {
AddAccessPattern(load->buffer, load->indices);
}
};

class IndexTransformer : public StmtExprMutator {
public:
explicit IndexTransformer(BufferAccessMap buffer_access_map, DependencyMap dependency_map)
: buffer_access_map_(std::move(buffer_access_map)),
dependency_map_(std::move(dependency_map)) {}

private:
std::pair<Buffer, PrimExpr> LowerIndices(SparseBuffer sp_buffer, Array<PrimExpr> indices) {
PrimExpr LowerIndices(SparseBuffer sp_buffer, const Array<PrimExpr>& indices) {
int ndim = sp_buffer->ndim();
ICHECK_EQ(static_cast<int>(indices.size()), ndim);
int n_lower = static_cast<int>(indices.size());
ICHECK_LE(n_lower, ndim);

PrimExpr lowered_index = Integer(0);

for (int i = 0; i < ndim; ++i) {
for (int i = 0; i < n_lower; ++i) {
const Axis& axis = sp_buffer->axes[i];
const PrimExpr& index = indices[i];

// Stage 1.
// Stage 1. Get the sparse index.
const auto* sp_iter = index.as<SpIterVarNode>();
PrimExpr sp_index{nullptr};
if (const auto* sp_iter = index.as<SpIterVarNode>()) {
SpIterKind kind = sp_iter->kind;
if (kind == SpIterKind::kDenseFixed) {
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
CHECK(ana.CanProveEqual(sp_iter->max_extent, df_axis->length));
sp_index = GetRef<SpIterVar>(sp_iter);
} else {
PrimExpr l = LowerIndex(lowered_index, sp_buffer, i, 0);
PrimExpr r = LowerIndex(Add(lowered_index, 1), sp_buffer, i, 0);
Var buffer_var;
if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
CHECK(ana.CanProveEqual(sp_iter->max_extent, sf_axis->length));
buffer_var = sf_axis->indices->data;
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
CHECK(ana.CanProveEqual(sp_iter->max_extent, sv_axis->length));
buffer_var = sv_axis->indices->data;
} else {
LOG(FATAL) << "Cannot reach here";
}
sp_index = lower_bound(buffer_var, index, std::move(l), std::move(r));
}
} else if (kind == SpIterKind::kDenseVariable) {
const auto* dv_axis = axis.as<DenseVariableAxisNode>();
CHECK(dv_axis != nullptr);
CHECK(sp_iter->axis.defined());
CHECK(sp_iter) << "ValueError: Currently an index is only allowed to be SpIterVar";

PrimExpr l = AccumulateLowerIndex(lowered_index, sp_buffer, i, 0);
PrimExpr r = AccumulateLowerIndex(Add(lowered_index, 1), sp_buffer, i, 0);

SpIterKind kind = sp_iter->kind;
if (kind == SpIterKind::kDenseFixed) {
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length));
sp_index = GetRef<SpIterVar>(sp_iter);
} else if (kind == SpIterKind::kSparseFixed) {
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
CHECK(sp_iter->axis.defined());
const Axis& iterated_axis = sp_iter->axis.value();
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
CHECK(ana.CanProveEqual(sp_iter->max_extent, df_axis->length));
// Todo: convert to dense
} else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
CHECK(ana.CanProveEqual(sp_iter->max_extent, sf_axis->length));
if (iterated_axis.get() == sf_axis) {
sp_index = GetRef<SpIterVar>(sp_iter);
} else {
// Todo: convert to dense and do binary search
}
} else {
Var buffer_var;
if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length));
buffer_var = sf_axis->indices->data;
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
CHECK(ana.CanProveEqual(sp_iter->max_extent, sv_axis->length));
// Todo: convert to dense and do binary search
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length));
buffer_var = sv_axis->indices->data;
} else {
LOG(FATAL) << "Cannot reach here";
}
} else {
CHECK(kind == SpIterKind::kSparseVariable);
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
CHECK(sp_iter->axis.defined());
const Axis& iterated_axis = sp_iter->axis.value();
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
CHECK(ana.CanProveEqual(sp_iter->max_extent, df_axis->length));
// Todo: convert to dense
} else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
CHECK(ana.CanProveEqual(sp_iter->max_extent, sf_axis->length));
// Todo: convert to dense and do binary search
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
CHECK(ana.CanProveEqual(sp_iter->max_extent, sv_axis->length));
if (iterated_axis.get() == sv_axis) {
sp_index = GetRef<SpIterVar>(sp_iter);
} else {
// Todo: convert to dense and do binary search
}
sp_index = lower_bound(buffer_var, index, std::move(l), std::move(r));
}
} else if (kind == SpIterKind::kDenseVariable) {
const auto* dv_axis = axis.as<DenseVariableAxisNode>();
CHECK(dv_axis != nullptr);
CHECK(sp_iter->axis.defined());
sp_index = GetRef<SpIterVar>(sp_iter);
} else if (kind == SpIterKind::kSparseFixed) {
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
CHECK(sp_iter->axis.defined());
const Axis& iterated_axis = sp_iter->axis.value();
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length));
sp_index = GetDenseValue(sp_iter);
} else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length));
if (iterated_axis.get() == sf_axis) {
sp_index = GetRef<SpIterVar>(sp_iter);
} else {
LOG(FATAL) << "Cannot reach here";
sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
std::move(r));
}
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length));
sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
std::move(r));
} else {
LOG(FATAL) << "Cannot reach here";
}
} else {
// Todo
CHECK(kind == SpIterKind::kSparseVariable);
CHECK(!axis->IsInstance<DenseVariableAxisNode>());
CHECK(sp_iter->axis.defined());
const Axis& iterated_axis = sp_iter->axis.value();
if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length));
sp_index = GetDenseValue(sp_iter);
} else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length));
sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
std::move(r));
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length));
if (iterated_axis.get() == sv_axis) {
sp_index = GetRef<SpIterVar>(sp_iter);
} else {
sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
std::move(r));
}
} else {
LOG(FATAL) << "Cannot reach here";
}
}

// Stage 2.
lowered_index = LowerIndex(std::move(lowered_index), sp_buffer, i, sp_index);
// Stage 2. Accumulate the lowered index.
lowered_index =
AccumulateLowerIndex(std::move(lowered_index), sp_buffer, i, std::move(sp_index));
}

return std::make_pair(sp_buffer->data, lowered_index);
return lowered_index;
}

PrimExpr LowerIndex(PrimExpr prev_lowered_index, SparseBuffer sp_buffer, int dim,
PrimExpr index) {
PrimExpr AccumulateLowerIndex(PrimExpr prev_lowered_index, const SparseBuffer& sp_buffer, int dim,
PrimExpr index) {
const Axis& axis = sp_buffer->axes[dim];
if (axis->IsInstance<DenseFixedAxisNode>() || axis->IsInstance<SparseFixedAxisNode>()) {
return ana.Simplify(prev_lowered_index * axis->length + index);
return ana_.Simplify(std::move(prev_lowered_index) * axis->length + std::move(index));
} else if (const auto* dv_axis = axis.as<DenseVariableAxisNode>()) {
return ana.Simplify(Add(BufferLoad(dv_axis->indptr, {prev_lowered_index}), index));
return ana_.Simplify(
Add(BufferLoad(dv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index)));
} else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
return ana.Simplify(Add(BufferLoad(sv_axis->indptr, {prev_lowered_index}), index));
return ana_.Simplify(
Add(BufferLoad(sv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index)));
}
LOG(FATAL) << "Cannot reach here";
throw;
}

PrimExpr GetDenseValue(const SpIterVarNode* sp_iter) {
SpIterKind kind = sp_iter->kind;
CHECK(kind == SpIterKind::kSparseFixed || kind == SpIterKind::kSparseVariable);
Axis iterated_axis = sp_iter->axis.value();

std::pair<SparseBuffer, int> depended_pair = dependency_map_[GetRef<SpIterVar>(sp_iter)];
Array<SpIterVar> buffer_access_iters = buffer_access_map_[depended_pair.first];
int n_depended = depended_pair.second;

Array<PrimExpr> depended_iters{buffer_access_iters.begin(),
buffer_access_iters.begin() + n_depended};
PrimExpr lowered_indices = LowerIndices(depended_pair.first, depended_iters);

if (kind == SpIterKind::kSparseFixed) {
return BufferLoad(Downcast<SparseFixedAxis>(iterated_axis)->indices,
{std::move(lowered_indices)});
} else {
return BufferLoad(Downcast<SparseVariableAxis>(iterated_axis)->indices,
{std::move(lowered_indices)});
}
}

PrimExpr VisitExpr_(const SparseBufferLoadNode* load) final {
std::pair<Buffer, PrimExpr> res = LowerIndices(load->buffer, load->indices);
return BufferLoad(std::move(res.first), {std::move(res.second)});
PrimExpr lowered_indices = LowerIndices(load->buffer, load->indices);
return BufferLoad(load->buffer->data, {std::move(lowered_indices)});
}

Stmt VisitStmt_(const SparseBufferStoreNode* store) final {
PrimExpr value = ExprMutator::VisitExpr(store->value);
std::pair<Buffer, PrimExpr> res = LowerIndices(store->buffer, store->indices);
return BufferStore(std::move(res.first), std::move(value), {std::move(res.second)});
PrimExpr lowered_indices = LowerIndices(store->buffer, store->indices);
return BufferStore(store->buffer->data, std::move(value), {std::move(lowered_indices)});
}

arith::Analyzer ana;
BufferAccessMap buffer_access_map_;
DependencyMap dependency_map_;
arith::Analyzer ana_;
};

PrimFunc LowerSparseTIR(PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = SparseTIRLowerer()(std::move(f->body));
AccessAndDependencyCollector collector;
collector.Collect(f->body);
fptr->body = IndexTransformer(collector.buffer_access_map_,
collector.dependency_map_)(std::move(f->body));
return f;
} else {
return f;
Expand Down

0 comments on commit deb46df

Please sign in to comment.