From 44ccb6d94bdad4651c052903a8d6d6285064e85f Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Thu, 4 Nov 2021 03:36:44 +0800
Subject: [PATCH] [SparseTIR] Index Lowering (#8)

* Add StmtFunctor/ExprFunctor for SparseBufferStore/Load
* Add basic index lowering
* Finish index lowering (maybe)
* Address comments
* Convert CRLF to LF
---
 include/tvm/tir/expr_functor.h         |   4 +
 include/tvm/tir/stmt_functor.h         |   4 +
 include/tvm/tir/transform.h            |   6 +
 python/tvm/tir/transform/transform.py  |  11 +
 src/tir/ir/expr_functor.cc             |  14 ++
 src/tir/ir/sparse.cc                   |  17 +-
 src/tir/ir/stmt_functor.cc             |  19 ++
 src/tir/transforms/lower_sparse_tir.cc | 306 +++++++++++++++++++++++++
 8 files changed, 379 insertions(+), 2 deletions(-)
 create mode 100644 src/tir/transforms/lower_sparse_tir.cc

diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h
index b5f1d64a00c4..2507e734c7a7 100644
--- a/include/tvm/tir/expr_functor.h
+++ b/include/tvm/tir/expr_functor.h
@@ -119,6 +119,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
     return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
   }
   virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const SparseBufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const ProducerLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
@@ -165,6 +166,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
     IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
     IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
     IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
+    IR_EXPR_FUNCTOR_DISPATCH(SparseBufferLoadNode);
     IR_EXPR_FUNCTOR_DISPATCH(ProducerLoadNode);
     IR_EXPR_FUNCTOR_DISPATCH(LetNode);
     IR_EXPR_FUNCTOR_DISPATCH(CallNode);
@@ -217,6 +219,7 @@ class TVM_DLL ExprVisitor : public ExprFunctor<void(const PrimExpr&)> {
   void VisitExpr_(const SizeVarNode* op) override;
   void VisitExpr_(const LoadNode* op) override;
   void VisitExpr_(const BufferLoadNode* op) override;
+  void VisitExpr_(const SparseBufferLoadNode* op) override;
   void VisitExpr_(const ProducerLoadNode* op) override;
   void VisitExpr_(const LetNode* op) override;
   void VisitExpr_(const CallNode* op) override;
@@ -264,6 +267,7 @@ class TVM_DLL ExprMutator : protected ExprFunctor<PrimExpr(const PrimExpr&)> {
   PrimExpr VisitExpr_(const SizeVarNode* op) override;
   PrimExpr VisitExpr_(const LoadNode* op) override;
   PrimExpr VisitExpr_(const BufferLoadNode* op) override;
+  PrimExpr VisitExpr_(const SparseBufferLoadNode* op) override;
   PrimExpr VisitExpr_(const ProducerLoadNode* op) override;
   PrimExpr VisitExpr_(const LetNode* op) override;
   PrimExpr VisitExpr_(const CallNode* op) override;
diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h
index 24773a5a471f..7185829f2b70 100644
--- a/include/tvm/tir/stmt_functor.h
+++ b/include/tvm/tir/stmt_functor.h
@@ -89,6 +89,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
   virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
+  virtual R VisitStmt_(const SparseBufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
   virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -121,6 +122,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
     IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode);
     IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
     IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
+    IR_STMT_FUNCTOR_DISPATCH(SparseBufferStoreNode);
     IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
     IR_STMT_FUNCTOR_DISPATCH(BlockNode);
     IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
@@ -157,6 +159,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
   void VisitStmt_(const AllocateNode* op) override;
   void VisitStmt_(const StoreNode* op) override;
   void VisitStmt_(const BufferStoreNode* op) override;
+  void VisitStmt_(const SparseBufferStoreNode* op) override;
   void VisitStmt_(const BufferRealizeNode* op) override;
   void VisitStmt_(const AssertStmtNode* op) override;
   void VisitStmt_(const ProducerStoreNode* op) override;
@@ -257,6 +260,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
   Stmt VisitStmt_(const AllocateNode* op) override;
   Stmt VisitStmt_(const StoreNode* op) override;
   Stmt VisitStmt_(const BufferStoreNode* op) override;
+  Stmt VisitStmt_(const SparseBufferStoreNode* op) override;
   Stmt VisitStmt_(const BufferRealizeNode* op) override;
   Stmt VisitStmt_(const AssertStmtNode* op) override;
   Stmt VisitStmt_(const ProducerStoreNode* op) override;
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index e6b0af9773d9..2dd6024e8002 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -472,6 +472,12 @@ TVM_DLL Pass MergeDynamicSharedMemoryAllocations();
  */
 TVM_DLL Pass ConvertForLoopsToSerial();
 
+/*!
+ * \brief Lower SparseTIR to TIR.
+ * \return The pass.
+ */
+TVM_DLL Pass LowerSparseTIR();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 722810e9aa5b..4912a4c2e728 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -726,3 +726,14 @@ def ConvertForLoopsToSerial():
         The result pass
     """
     return _ffi_api.ConvertForLoopsToSerial()  # type: ignore
+
+
+def LowerSparseTIR():
+    """Lower SparseTIR to TIR
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LowerSparseTIR()  # type: ignore
diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc
index 4c5ea5bfd2d0..b7e0665cf9fd 100644
--- a/src/tir/ir/expr_functor.cc
+++ b/src/tir/ir/expr_functor.cc
@@ -43,6 +43,10 @@ void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
   VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
 }
 
+void ExprVisitor::VisitExpr_(const SparseBufferLoadNode* op) {
+  VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
+}
+
 void ExprVisitor::VisitExpr_(const ProducerLoadNode* op) {
   VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
 }
@@ -146,6 +150,16 @@ PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
   }
 }
 
+PrimExpr ExprMutator::VisitExpr_(const SparseBufferLoadNode* op) {
+  auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
+  Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
+  if (indices.same_as(op->indices)) {
+    return GetRef<PrimExpr>(op);
+  } else {
+    return SparseBufferLoad(op->buffer, indices);
+  }
+};
+
 PrimExpr ExprMutator::VisitExpr_(const ProducerLoadNode* op) {
   auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
   Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc
index 9154d96f818f..f9c9203ed369 100644
--- a/src/tir/ir/sparse.cc
+++ b/src/tir/ir/sparse.cc
@@ -21,6 +21,7 @@
  * \file sparse.cc
  * \brief buffers and formats in sparse tir.
  */
+#include <tvm/arith/analyzer.h>
 #include
 #include
 #include
@@ -158,9 +159,21 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_
                      Optional<Axis> axis) {
   ObjectPtr<SpIterVarNode> node = make_object<SpIterVarNode>();
+  arith::Analyzer ana;
+  if (axis.defined()) {
+    CHECK(ana.CanProveEqual(axis.value()->length, max_extent));
+  }
   if (kind != SpIterKind::kDenseFixed) {
     CHECK(axis.defined()) << "ValueError: To create a SpIterVar that is not fixed-dense, one must "
                              "specify the axis over which the SpIterVar iterates";
+    const char* err_str = "ValueError: The given kind doesn't match the type of the given axis";
+    if (kind == SpIterKind::kDenseVariable) {
+      CHECK(axis.value()->IsInstance<DenseVariableAxisNode>()) << err_str;
+    } else if (kind == SpIterKind::kSparseFixed) {
+      CHECK(axis.value()->IsInstance<SparseFixedAxisNode>()) << err_str;
+    } else if (kind == SpIterKind::kSparseVariable) {
+      CHECK(axis.value()->IsInstance<SparseVariableAxisNode>()) << err_str;
+    }
   }
 
   node->var = Var(std::move(name));
@@ -174,9 +187,9 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_
 TVM_REGISTER_NODE_TYPE(SpIterVarNode);
 
 TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar")
-    .set_body_typed([](String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
+    .set_body_typed([](String name, PrimExpr max_extent, int kind, bool is_reduction,
                        Optional<Axis> axis) {
-      return SpIterVar(name, max_extent, kind, is_reduction, axis);
+      return SpIterVar(name, max_extent, SpIterKind(kind), is_reduction, axis);
     });
 
 }  // namespace tir
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index d60ec72a7589..2a0c43904c70 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -69,6 +69,11 @@ void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
   VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
 }
 
+void StmtVisitor::VisitStmt_(const SparseBufferStoreNode* op) {
+  this->VisitExpr(op->value);
+  VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
+}
+
 void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
   VisitArray(op->bounds, [this](const Range& r) {
     this->VisitExpr(r->min);
@@ -367,6 +372,20 @@ Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
   }
 }
 
+Stmt StmtMutator::VisitStmt_(const SparseBufferStoreNode* op) {
+  PrimExpr value = this->VisitExpr(op->value);
+  Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
+
+  if (value.same_as(op->value) && indices.same_as(op->indices)) {
+    return GetRef<Stmt>(op);
+  } else {
+    auto n = CopyOnWrite(op);
+    n->value = std::move(value);
+    n->indices = std::move(indices);
+    return Stmt(n);
+  }
+}
+
 Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
   Region bounds = Internal::Mutate(this, op->bounds);
   PrimExpr condition = this->VisitExpr(op->condition);
diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc
new file mode 100644
index 000000000000..34d0b3d05e3b
--- /dev/null
+++ b/src/tir/transforms/lower_sparse_tir.cc
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lower_sparse_tir.cc
+ */
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Check whether a given SparseBuffer contains the given axis.
+ * \param buffer The SparseBuffer to be checked
+ * \param axis The axis to be checked
+ * \return A boolean indicating whether the given SparseBuffer contains the given axis
+ */
+bool BufferContainsAxis(const SparseBuffer& buffer, const Axis& axis) {
+  for (int i = 0; i < static_cast<int>(buffer->axes.size()); ++i) {
+    if (buffer->axes[i].same_as(axis)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+using BufferAccessMap = Map<SparseBuffer, Array<SpIterVar>>;
+using DependencyMap =
+    std::unordered_map<SpIterVar, std::pair<SparseBuffer, int>, ObjectPtrHash, ObjectPtrEqual>;
+
+/*!
+ * \brief For each sparse-fixed or sparse-variable iterator, collect the iterators that it depends
+ * on.
+ */
+class AccessAndDependencyCollector : public StmtExprVisitor {
+ public:
+  void Collect(Stmt stmt) {
+    VisitStmt(std::move(stmt));
+
+    for (const std::pair<SparseBuffer, Array<SpIterVar>>& kv_pair : buffer_access_map_) {
+      const SparseBuffer& buffer = kv_pair.first;
+      int ndim = static_cast<int>(kv_pair.second.size());
+      for (int k = 0; k < ndim; ++k) {
+        const SpIterVar& sp_iter = kv_pair.second[k];
+        if (sp_iter->kind == SpIterKind::kDenseFixed ||
+            sp_iter->kind == SpIterKind::kDenseVariable ||
+            !BufferContainsAxis(buffer, sp_iter->axis.value())) {
+          continue;
+        }
+
+        ICHECK(dependency_map_.count(sp_iter) == 0);
+        dependency_map_[sp_iter] = std::make_pair(buffer, k);
+      }
+    }
+  }
+
+  BufferAccessMap buffer_access_map_;
+  DependencyMap dependency_map_;
+
+ private:
+  void AddAccessPattern(const SparseBuffer& buffer, const Array<PrimExpr>& indices) {
+    int ndim = buffer->ndim();
+    CHECK_EQ(static_cast<int>(indices.size()), ndim);
+
+    Array<SpIterVar> iters;
+    iters.reserve(ndim);
+    for (int i = 0; i < ndim; ++i) {
+      const SpIterVarNode* sp_iter = indices[i].as<SpIterVarNode>();
+      CHECK(sp_iter) << "ValueError: Currently an index is only allowed to be SpIterVar";
+      iters.push_back(GetRef<SpIterVar>(sp_iter));
+    }
+
+    BufferAccessMap::iterator it = buffer_access_map_.find(buffer);
+    if (it == buffer_access_map_.end()) {
+      buffer_access_map_.Set(buffer, iters);
+    } else {
+      ICHECK_EQ(static_cast<int>((*it).second.size()), ndim);
+      for (int i = 0; i < ndim; ++i) {
+        CHECK((*it).second[i].same_as(iters[i]))
+            << "ValueError: Currently all accesses to a same buffer are required to be the same";
+      }
+    }
+  }
+
+  void VisitStmt_(const SparseBufferStoreNode* store) final {
+    ExprVisitor::VisitExpr(store->value);
+    AddAccessPattern(store->buffer, store->indices);
+  }
+
+  void VisitExpr_(const SparseBufferLoadNode* load) final {
+    AddAccessPattern(load->buffer, load->indices);
+  }
+};
+
+class IndexTransformer : public StmtExprMutator {
+ public:
+  explicit IndexTransformer(BufferAccessMap buffer_access_map, DependencyMap dependency_map)
+      : buffer_access_map_(std::move(buffer_access_map)),
+        dependency_map_(std::move(dependency_map)) {}
+
+ private:
+  PrimExpr LowerIndices(SparseBuffer sp_buffer, const Array<PrimExpr>& indices) {
+    int ndim = sp_buffer->ndim();
+    int n_lower = static_cast<int>(indices.size());
+    ICHECK_LE(n_lower, ndim);
+
+    PrimExpr lowered_index = Integer(0);
+
+    for (int i = 0; i < n_lower; ++i) {
+      const Axis& axis = sp_buffer->axes[i];
+      const PrimExpr& index = indices[i];
+
+      // Stage 1. Get the sparse index.
+      const auto* sp_iter = index.as<SpIterVarNode>();
+      PrimExpr sp_index{nullptr};
+      CHECK(sp_iter) << "ValueError: Currently an index is only allowed to be SpIterVar";
+
+      PrimExpr l = AccumulateLowerIndex(lowered_index, sp_buffer, i, 0);
+      PrimExpr r = AccumulateLowerIndex(add(lowered_index, 1), sp_buffer, i, 0);
+
+      SpIterKind kind = sp_iter->kind;
+      if (kind == SpIterKind::kDenseFixed) {
+        CHECK(!axis->IsInstance<DenseVariableAxisNode>());
+        if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
+          CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length));
+          sp_index = GetRef<SpIterVar>(sp_iter);
+        } else {
+          Var buffer_var;
+          if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
+            CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length));
+            buffer_var = sf_axis->indices->data;
+          } else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
+            CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length));
+            buffer_var = sv_axis->indices->data;
+          } else {
+            LOG(FATAL) << "Cannot reach here";
+          }
+          sp_index = lower_bound(buffer_var, index, std::move(l), std::move(r));
+        }
+      } else if (kind == SpIterKind::kDenseVariable) {
+        const auto* dv_axis = axis.as<DenseVariableAxisNode>();
+        CHECK(dv_axis != nullptr);
+        CHECK(sp_iter->axis.defined());
+        sp_index = GetRef<SpIterVar>(sp_iter);
+      } else if (kind == SpIterKind::kSparseFixed) {
+        CHECK(!axis->IsInstance<DenseVariableAxisNode>());
+        CHECK(sp_iter->axis.defined());
+        const Axis& iterated_axis = sp_iter->axis.value();
+        if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
+          CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length));
+          sp_index = GetDenseValue(sp_iter);
+        } else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
+          CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length));
+          if (iterated_axis.get() == sf_axis) {
+            sp_index = GetRef<SpIterVar>(sp_iter);
+          } else {
+            sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
+                                   std::move(r));
+          }
+        } else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
+          CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length));
+          sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
+                                 std::move(r));
+        } else {
+          LOG(FATAL) << "Cannot reach here";
+        }
+      } else {
+        CHECK(kind == SpIterKind::kSparseVariable);
+        CHECK(!axis->IsInstance<DenseVariableAxisNode>());
+        CHECK(sp_iter->axis.defined());
+        const Axis& iterated_axis = sp_iter->axis.value();
+        if (const auto* df_axis = axis.as<DenseFixedAxisNode>()) {
+          CHECK(ana_.CanProveEqual(sp_iter->max_extent, df_axis->length));
+          sp_index = GetDenseValue(sp_iter);
+        } else if (const auto* sf_axis = axis.as<SparseFixedAxisNode>()) {
+          CHECK(ana_.CanProveEqual(sp_iter->max_extent, sf_axis->length));
+          sp_index = lower_bound(sf_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
+                                 std::move(r));
+        } else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
+          CHECK(ana_.CanProveEqual(sp_iter->max_extent, sv_axis->length));
+          if (iterated_axis.get() == sv_axis) {
+            sp_index = GetRef<SpIterVar>(sp_iter);
+          } else {
+            sp_index = lower_bound(sv_axis->indices->data, GetDenseValue(sp_iter), std::move(l),
+                                   std::move(r));
+          }
+        } else {
+          LOG(FATAL) << "Cannot reach here";
+        }
+      }
+
+      // Stage 2. Accumulate the lowered index.
+      lowered_index =
+          AccumulateLowerIndex(std::move(lowered_index), sp_buffer, i, std::move(sp_index));
+    }
+
+    return lowered_index;
+  }
+
+  PrimExpr AccumulateLowerIndex(PrimExpr prev_lowered_index, const SparseBuffer& sp_buffer, int dim,
+                                PrimExpr index) {
+    const Axis& axis = sp_buffer->axes[dim];
+    if (axis->IsInstance<DenseFixedAxisNode>() || axis->IsInstance<SparseFixedAxisNode>()) {
+      return ana_.Simplify(std::move(prev_lowered_index) * axis->length + std::move(index));
+    } else if (const auto* dv_axis = axis.as<DenseVariableAxisNode>()) {
+      return ana_.Simplify(
+          add(BufferLoad(dv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index)));
+    } else if (const auto* sv_axis = axis.as<SparseVariableAxisNode>()) {
+      return ana_.Simplify(
+          add(BufferLoad(sv_axis->indptr, {std::move(prev_lowered_index)}), std::move(index)));
+    }
+    LOG(FATAL) << "Cannot reach here";
+    throw;
+  }
+
+  PrimExpr GetDenseValue(const SpIterVarNode* sp_iter) {
+    SpIterKind kind = sp_iter->kind;
+    CHECK(kind == SpIterKind::kSparseFixed || kind == SpIterKind::kSparseVariable);
+    Axis iterated_axis = sp_iter->axis.value();
+
+    std::pair<SparseBuffer, int> dependent_pair = dependency_map_[GetRef<SpIterVar>(sp_iter)];
+    Array<SpIterVar> buffer_access_iters = buffer_access_map_[dependent_pair.first];
+    int n_dependent = dependent_pair.second;
+
+    Array<PrimExpr> dependent_iters{buffer_access_iters.begin(),
+                                    buffer_access_iters.begin() + n_dependent};
+    PrimExpr lowered_indices = LowerIndices(dependent_pair.first, dependent_iters);
+
+    if (kind == SpIterKind::kSparseFixed) {
+      return BufferLoad(Downcast<SparseFixedAxis>(iterated_axis)->indices,
+                        {std::move(lowered_indices)});
+    } else {
+      return BufferLoad(Downcast<SparseVariableAxis>(iterated_axis)->indices,
+                        {std::move(lowered_indices)});
+    }
+  }
+
+  PrimExpr VisitExpr_(const SparseBufferLoadNode* load) final {
+    PrimExpr lowered_indices = LowerIndices(load->buffer, load->indices);
+    return BufferLoad(load->buffer->data, {std::move(lowered_indices)});
+  }
+
+  Stmt VisitStmt_(const SparseBufferStoreNode* store) final {
+    PrimExpr value = ExprMutator::VisitExpr(store->value);
+    PrimExpr lowered_indices = LowerIndices(store->buffer, store->indices);
+    return BufferStore(store->buffer->data, std::move(value), {std::move(lowered_indices)});
+  }
+
+  BufferAccessMap buffer_access_map_;
+  DependencyMap dependency_map_;
+  arith::Analyzer ana_;
+};
+
+PrimFunc LowerSparseTIR(PrimFunc f) {
+  // Only apply this pass to TIR that is not from TE schedules
+  if (!IsFromLegacyTESchedule(f)) {
+    PrimFuncNode* fptr = f.CopyOnWrite();
+    AccessAndDependencyCollector collector;
+    collector.Collect(f->body);
+    fptr->body = IndexTransformer(collector.buffer_access_map_,
+                                  collector.dependency_map_)(std::move(f->body));
+    return f;
+  } else {
+    return f;
+  }
+}
+
+namespace transform {
+
+Pass LowerSparseTIR() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    return LowerSparseTIR(std::move(f));
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.LowerSparseTIR", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.LowerSparseTIR").set_body_typed(LowerSparseTIR);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
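
Note on the lowering arithmetic: the offset that LowerIndices/AccumulateLowerIndex emit can be hard to follow from the visitor code alone. The plain-Python sketch below models the intended computation for the common two-dimensional case of a dense-fixed row axis followed by a sparse-variable column axis (i.e. a CSR layout). The arrays indptr, indices, data and the helper lowered_offset are illustrative placeholders only; they are not part of this patch or of the TVM API.

    import bisect

    # Illustrative CSR data: 3 rows, stored non-zero columns per row are
    # row 0 -> [1, 3], row 1 -> [2], row 2 -> [0, 1, 3].
    indptr = [0, 2, 3, 6]
    indices = [1, 3, 2, 0, 1, 3]
    data = [10, 20, 30, 40, 50, 60]


    def lowered_offset(i, j):
        """Flat offset of A[i, j] in `data`, mirroring the two-stage lowering.

        Stage 1 resolves the sparse index: on the column axis the position of
        `j` is found by a lower_bound-style binary search restricted to the
        half-open range [l, r) = [indptr[i], indptr[i + 1]).
        Stage 2 accumulates the offset: the variable axis contributes its
        indptr entry, so the result equals indptr[i] plus the position of j
        within row i.
        """
        l, r = indptr[i], indptr[i + 1]
        pos = bisect.bisect_left(indices, j, l, r)
        assert pos < r and indices[pos] == j, "A[i, j] is not a stored non-zero"
        return pos


    # A[2, 1] lives at flat offset 4, so the lowered load reads data[4] == 50.
    assert lowered_offset(2, 1) == 4
    assert data[lowered_offset(2, 1)] == 50

In the pass itself, dense-fixed axes multiply the running offset by the axis length, variable axes go through indptr, and sparse axes additionally locate the stored position with lower_bound; the l/r expressions computed in Stage 1 are exactly the search bounds used above.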
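For completeness, here is a minimal sketch of how the new pass is expected to be invoked from Python once this patch is applied. The module name mod is a placeholder for an IRModule whose PrimFuncs contain SparseBufferLoad/SparseBufferStore nodes; constructing such a module is outside the scope of this patch, and the choice of follow-up passes is only an example.

    import tvm
    from tvm import tir

    # `mod` is assumed to be an IRModule built with SparseTIR buffers
    # (placeholder here, not constructed by this snippet).
    # mod = ...

    # LowerSparseTIR (added by this patch) rewrites sparse buffer accesses
    # into flat BufferLoad / BufferStore before the rest of the TIR pipeline.
    seq = tvm.transform.Sequential(
        [
            tir.transform.LowerSparseTIR(),
            tir.transform.Simplify(),
        ]
    )
    # lowered_mod = seq(mod)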