Skip to content

Commit

Permalink
Fix sparse reorder after refactor (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 committed Jan 17, 2022
1 parent 1c0d940 commit 48e8488
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 252 deletions.
48 changes: 46 additions & 2 deletions src/tir/schedule/primitive/sparse_loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,58 @@ void CheckValidInputIterators(const ScheduleState self, const Array<SpIterVar>&
}
}

/*!
* \brief Check whether the sparse reorder would break dependency between iterators.
* \param new_order The new iterator order to be checked.
* \throw ScheduleError If the sparse reorder breaks dependency.
*/
void CheckDependency(const ScheduleState self, const Array<SpIterVar>& new_order) {
class DependencyError : public ScheduleError {
public:
explicit DependencyError(IRModule mod, SpIterVar iter, Array<SpIterVar> new_order):
mod_(std::move(mod)), iter_(std::move(iter)), new_order_(std::move(new_order)) {}

String FastErrorString() const final {
return "ScheduleError: the sparse reorder breaks dependency between axes.";
}

String DetailRenderTemplate() const final {
std::ostringstream os;
os << "ScheduleError: in new order " << new_order_
<< " iterator " << iter_ << " was placed before its dependent iterator.";
return os.str();
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {}; }

IRModule mod_;
SpIterVar iter_;
Array<SpIterVar> new_order_;
};

std::set<Axis> axes_set;
for (const SpIterVar& sp_iter : new_order) {
Axis axis = sp_iter->axis;
auto try_parent = axis->GetParentAxis();
if (try_parent.defined()) {
Axis parent = try_parent.value();
if (axes_set.find(parent) == axes_set.end()) {
throw DependencyError(self->mod, sp_iter, new_order);
}
}
axes_set.insert(axis);
}
}


SparseBlock SparseReorder(ScheduleState self, const SparseBlock& block,
const Array<SpIterVar>& new_order) {
// Step 1. Check whether the iterators in `new_order` are the same as `block`'s iterators.
CheckValidInputIterators(self, new_order, block->sp_iter_vars);

// Step 2. Check whether the new order does not break the iterator dependency.
// TODO(zihao): rewrite this part.
// CheckDependency(self, block, new_order);
CheckDependency(self, new_order);

// Step 3. Create the new SparseBlock.
ObjectPtr<SparseBlockNode> p_new_block = make_object<SparseBlockNode>(*block.get());
Expand Down
17 changes: 11 additions & 6 deletions tests/python/sparsetir/test_tir_rgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,16 @@ def rgcn(
F_in = T.dense_fixed(feat_size)
F_out = T.dense_fixed(feat_size)
E = T.match_sparse_buffer(etype, (I, J), "int32")
W = T.match_sparse_buffer(w, (R, F_in, F_out), "float32")
W = T.match_sparse_buffer(w, (R, F_out, F_in), "float32")
X = T.match_sparse_buffer(x, (T.dense(J), F_in), "float32")
Y = T.match_sparse_buffer(y, (I, F_out), "float32")
T.func_attr({"global_symbol": "main", "tir.noalias": True})
with T.iter([I, F_out, J, F_in], "SSRR", "rgcn-forward") as [
vi, vout, vj, vin,
]:
with T.init():
Y[vi, vout] = 0.
Y[vi, vout] = Y[vi, vout] + W[E[vi, vj], vin, vout] * X[vj, vin]
Y[vi, vout] = Y[vi, vout] + W[E[vi, vj], vout, vin] * X[vj, vin]


@T.prim_func
Expand Down Expand Up @@ -179,15 +180,19 @@ def msg_func(edges):
print("dgl high-mem:\t\t", accum / (total - cold_start))

# tir
N, R, FEAT_SIZE, NNZ = lowered_rgcn.params[-4:]
mod = tvm.IRModule.from_expr(rgcn)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_rgcn, True)

N, R, FEAT_SIZE, NNZ = mod["main"].params[-4:]
sch = tir.Schedule(
lowered_rgcn.specialize(
mod["main"].specialize(
{N: g.number_of_nodes(), R: g.num_rels, FEAT_SIZE: feat_size, NNZ: g.number_of_edges()}
)
)

outer = sch.get_block("rgcn-forward_0")
inner = sch.get_block("rgcn-forward_1")
outer = sch.get_block("rgcn-forward0")
inner = sch.get_block("rgcn-forward1")
i, f_out = sch.get_loops(outer)
j, f_in = sch.get_loops(inner)
sch.bind(i, "blockIdx.x")
Expand Down
Loading

0 comments on commit 48e8488

Please sign in to comment.