Skip to content

Commit

Permalink
Axis Dependency Tree aware code-gen and bmm example (apache#28)
Browse files Browse the repository at this point in the history
* upd

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* remove redundancy

* fix

* upd

* upd
  • Loading branch information
yzh119 committed Nov 26, 2021
1 parent e2b64ef commit f25aa07
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 186 deletions.
22 changes: 20 additions & 2 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class AxisNode : public Object {
String GetName() const { return name; }
PrimExpr GetLength() const { return length; }
DataType GetIndexType() const { return length->dtype; }

virtual bool is_fixed() const = 0;

static constexpr const char* _type_key = "tir.sparse.Axis";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand Down Expand Up @@ -141,6 +143,10 @@ class DenseFixedAxisNode : public DenseAxisNode {
hash_reduce(from_sparse);
}

bool is_fixed() const {
return true;
}

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
};
Expand Down Expand Up @@ -177,6 +183,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
hash_reduce(indptr);
}

bool is_fixed() const {
return false;
}

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};
Expand Down Expand Up @@ -220,6 +230,10 @@ class SparseFixedAxisNode : public SparseAxisNode {
hash_reduce(nnz_cols);
}

bool is_fixed() const {
return true;
}

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
};
Expand Down Expand Up @@ -262,6 +276,10 @@ class SparseVariableAxisNode : public SparseAxisNode {
hash_reduce(indices);
}

bool is_fixed() const {
return false;
}

static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
};
Expand All @@ -283,9 +301,9 @@ class SparseVariableAxis : public SparseAxis {
class AxisTreeNode : public Object {
public:
// unordered map that stores the parent relationship between axes.
Map<String, Optional<String>> parent;
Map<String, String> parent;
// unordered map that stores the children relationship between axes.
Map<Optional<String>, Array<String>> children;
Map<String, Array<String>> children;

void VisitAttrs(AttrVisitor* v) {
v->Visit("parent", &parent);
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,10 @@ TVM_DLL Pass ConvertForLoopsToSerial();

/*!
* \brief Lower SparseTIR to TIR.
* \param axis_tree The axis dependency tree.
* \return The pass.
*/
TVM_DLL Pass LowerSparseTIR();
TVM_DLL Pass LowerSparseTIR(AxisTree axis_tree);

} // namespace transform
} // namespace tir
Expand Down
10 changes: 8 additions & 2 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Optional
from . import _ffi_api
from . import function_pass as _fpass
from ..sparse import AxisTree


def Apply(ftransform):
Expand Down Expand Up @@ -751,12 +752,17 @@ def ConvertForLoopsToSerial():
return _ffi_api.ConvertForLoopsToSerial() # type: ignore


def LowerSparseTIR():
def LowerSparseTIR(axis_tree: AxisTree):
"""Lower SparseTIR to TIR
Parameters
----------
axis_tree : AxisTree
The axis dependency tree.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerSparseTIR() # type: ignore
return _ffi_api.LowerSparseTIR(axis_tree) # type: ignore
9 changes: 6 additions & 3 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,15 @@ AxisTree::AxisTree(Array<String> axis_names, Array<Optional<String>> axis_parent
"axis_parent_names "
"array.";
ObjectPtr<AxisTreeNode> node = make_object<AxisTreeNode>();
Map<String, Optional<String>> parent;
Map<Optional<String>, Array<String>> children;
Map<String, String> parent;
Map<String, Array<String>> children;
for (size_t i = 0; i < axis_names.size(); i++) {
// update parent map & children map
String axis_name = axis_names[i];
Optional<String> parent_name = axis_parent_names[i];
String parent_name("root");
if (axis_parent_names[i].defined()) {
parent_name = axis_parent_names[i].value();
}
parent.Set(axis_name, parent_name);

auto it = children.find(parent_name);
Expand Down
149 changes: 118 additions & 31 deletions src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <set>
#include <stack>
#include <utility>

#include "../schedule/analysis.h"
Expand Down Expand Up @@ -87,8 +89,8 @@ Map<Var, Buffer> UpdateBufferMap(PrimFunc f) {
*/
class IndexTransformer : public StmtExprMutator {
public:
explicit IndexTransformer(AccessAndDependencyCollector collector)
: collector_(std::move(collector)) {}
explicit IndexTransformer(AccessAndDependencyCollector collector, AxisTree axis_tree)
: collector_(std::move(collector)), axis_tree_(std::move(axis_tree)) {}

private:
/*!
Expand Down Expand Up @@ -281,43 +283,124 @@ class IndexTransformer : public StmtExprMutator {
sp_block->init.defined() ? VisitStmt(sp_block->init.value()) : Optional<Stmt>(NullOpt);
Stmt body = VisitStmt(sp_block->body);

// Step 2. Create the new outer loop vars.
Array<Var> loop_vars;
// Step 2. Create the new loop vars.
std::unordered_map<const VarNode*, PrimExpr> var_map;
loop_vars.reserve(n_iter);
Array<Var> all_loop_vars;
var_map.reserve(n_iter);
for (const SpIterVar& sp_iter : sp_block->sp_iter_vars) {
Var loop_var("v_" + sp_iter->var->name_hint);
loop_vars.push_back(loop_var);
all_loop_vars.push_back(loop_var);
var_map[sp_iter->var.get()] = loop_var;
}

// Step 3. Create block iters and iter bindings.
Array<IterVar> block_iters;
Array<PrimExpr> iter_bindings;
block_iters.reserve(n_iter);
iter_bindings.reserve(n_iter);
for (int i = 0; i < n_iter; ++i) {
block_iters.push_back(SpIterVarToIterVar(sp_block->sp_iter_vars[i], var_map));
iter_bindings.push_back(loop_vars[i]);
}
// Step 3. Collet block iters and iter bindings.
std::set<String> in_stack;
in_stack.insert("root");
/* A stack that stores block itervars in each block. */
std::stack<Array<IterVar>> block_iters_st;
/* A stack that stores itervar bindings in each block. */
std::stack<Array<PrimExpr>> iter_bindings_st;
/* A stack that stores generated loop vars in each block. */
std::stack<Array<Var>> loop_vars_st;
/* A stack that stores whether to place init block in each block. */
std::stack<bool> place_init_st;
/* An indicator that records whether init block has been set. */
bool init_set = false;
do {
/* Block itervars of current block. */
Array<IterVar> block_iters;
/* Itervar bindings of current block. */
Array<PrimExpr> iter_bindings;
/* Axis names of current block. */
Array<String> axis_names;
/* Generated loop vars of current block. */
Array<Var> loop_vars;
/* An indicator that records whether there is reduction axis in current block. */
bool has_reduction_var = false;
for (int i = 0; i < n_iter; ++i) {
SpIterVar sp_it_var = sp_block->sp_iter_vars[i];
String axis_name = sp_it_var->axis->name;
auto&& parent_axis = axis_tree_->parent.Get(axis_name);
CHECK(parent_axis.defined()) << "Sparse IterVar not defined in Axis Tree.";
String parent_axis_name = parent_axis.value();
bool is_fixed_axis = sp_it_var->axis->is_fixed();
/* Add itervar to current block when
* - it's not used yet (not in stack) and
* - it's parent axis was used in outer blocks or
* - it's an iterator to a fixed axis.
*/
if ((is_fixed_axis || in_stack.find(parent_axis_name) != in_stack.end()) &&
in_stack.find(axis_name) == in_stack.end()) {
loop_vars.push_back(all_loop_vars[i]);
axis_names.push_back(std::move(axis_name));
block_iters.push_back(SpIterVarToIterVar(sp_it_var, var_map));
iter_bindings.push_back(all_loop_vars[i]);
has_reduction_var |= sp_it_var->is_reduction;
}
}

/* Tag axes in current block as "in-stack". */
for (const String&& axis_name : axis_names) {
in_stack.insert(std::move(axis_name));
}

/* Update stack. */
if (!block_iters.empty()) {
block_iters_st.push(std::move(block_iters));
iter_bindings_st.push(std::move(iter_bindings));
loop_vars_st.push(std::move(loop_vars));
if (init_set) {
place_init_st.push(false);
} else {
place_init_st.push(has_reduction_var);
init_set |= has_reduction_var;
}
} else {
break;
}
} while (true);

// Step 4. Generate the read-region and write-retion of the block.
Array<BufferRegion> reads{nullptr};
Array<BufferRegion> writes{nullptr};
GenerateReadWriteRegions(sp_block, &reads, &writes);

// Step 5. Create the block and block-realize
Map<String, ObjectRef> mapping;
mapping.Set("sparse", Bool(true));
Block block(block_iters, std::move(reads), std::move(writes), sp_block->name, std::move(body),
std::move(init), {}, {}, std::move(mapping));
BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block));

// Step 6. Create outer loops and the block binding.
Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars);
// Step 5. Generate nested blocks and loops from innermost to outermost.
int blk_counter = 0;
while (!block_iters_st.empty()) {
Array<IterVar> block_iters = std::move(block_iters_st.top());
Array<PrimExpr> iter_bindings = std::move(iter_bindings_st.top());
Array<Var> loop_vars = std::move(loop_vars_st.top());
bool place_init = place_init_st.top();
block_iters_st.pop();
iter_bindings_st.pop();
loop_vars_st.pop();
place_init_st.pop();

Map<String, ObjectRef> mapping;
mapping.Set("sparse", Bool(true));
String blk_name_hint = sp_block->name;
if (blk_counter != 0) {
blk_name_hint = blk_name_hint + "_" + std::to_string(blk_counter);
}
Block block(/*iter_vars=*/block_iters,
/*reads=*/reads,
/*writes=*/writes,
/*name_hint=*/blk_name_hint,
/*body=*/std::move(body),
/*init=*/place_init ? std::move(init) : NullOpt,
/*alloc_buffers=*/{},
/*match_buffers=*/{},
/*annotations=*/std::move(mapping),
/*span=*/sp_block->span);
BlockRealize block_realize(std::move(iter_bindings), const_true(), std::move(block));
// Generate outer loop and the block binding.
Stmt loop = GenerateLoops(std::move(block_realize), block_iters, loop_vars);
body = loop;
blk_counter += 1;
}

return loop;
return body;
}

/*!
Expand Down Expand Up @@ -380,9 +463,10 @@ class IndexTransformer : public StmtExprMutator {
}

/*!
* \brief generated nested for loops for sparse block.
* \brief generated nested for-loops for sparse block.
* \param block_iters The iterators defined in sparse blocks.
* \param loop_vars The loop variables binded with block iterators.
* \return The outermost loop.
*/
Stmt GenerateLoops(Stmt body, const Array<IterVar>& block_iters, const Array<Var>& loop_vars) {
int n_iter = static_cast<int>(block_iters.size());
Expand All @@ -394,6 +478,7 @@ class IndexTransformer : public StmtExprMutator {
}

AccessAndDependencyCollector collector_;
AxisTree axis_tree_;
arith::Analyzer ana_;
std::unordered_set<const SparseBufferNode*> buffer_read_;
std::unordered_set<const SparseBufferNode*> buffer_write_;
Expand All @@ -411,11 +496,12 @@ Stmt WrapWithRootBlock(Stmt body) {
}

/*!
* \brief Rewrite the given primitive function
* \brief Rewrite the given primitive function.
* \param axis_tree The axis dependency tree.
* \param f The Sparse-TIR primitive function to lower.
* \return lowered primitive function in TIR.
*/
PrimFunc LowerSparseTIR(PrimFunc f) {
PrimFunc LowerSparseTIR(AxisTree axis_tree, PrimFunc f) {
// Only apply this pass to TIR that is not from TE schedules
if (!IsFromLegacyTESchedule(f)) {
PrimFuncNode* fptr = f.CopyOnWrite();
Expand All @@ -425,7 +511,7 @@ PrimFunc LowerSparseTIR(PrimFunc f) {
AccessAndDependencyCollector collector;
collector.Collect(f->body);
// Step 3. Lower indices.
fptr->body = IndexTransformer(collector)(std::move(f->body));
fptr->body = IndexTransformer(collector, axis_tree)(std::move(f->body));
// Step 4. Wrap the function body with a root block.
fptr->body = WrapWithRootBlock(std::move(fptr->body));
return f;
Expand All @@ -438,10 +524,11 @@ namespace transform {

/*!
* \brief The lowering pass from TIR to Sparse TIR.
* \param axis_tree The axis dependency tree.
*/
Pass LowerSparseTIR() {
Pass LowerSparseTIR(AxisTree axis_tree) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return LowerSparseTIR(std::move(f));
return LowerSparseTIR(std::move(axis_tree), std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerSparseTIR", {});
}
Expand Down
Loading

0 comments on commit f25aa07

Please sign in to comment.