Skip to content

Commit

Permalink
[SCAN] Enable fix point analysis for scan
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 17, 2017
1 parent d114dfc commit 98e830b
Show file tree
Hide file tree
Showing 12 changed files with 523 additions and 91 deletions.
3 changes: 2 additions & 1 deletion include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ enum AttachType : int {
kNone = 0,
kRoot = 1,
kInline = 2,
kScope = 3
kScope = 3,
kScanUpdate = 4
};

/*! \brief IterVar type */
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ class OperationNode : public FunctionBaseNode {
virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */
virtual Array<Expr> output_shape(size_t i) const = 0;

static constexpr const char* _type_key = "Operation";
};

// Implementations of inline functions
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def build(sch,
arg_list.append(x)
else:
raise ValueError("args must be Tensor, Buffer or Var")
# lowering
# normalize schedule first
sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.StorageFlatten(stmt, binds)
Expand Down
3 changes: 3 additions & 0 deletions src/api/api_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(CreateAttachPath);
REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS2(ScheduleOps);

} // namespace schedule
Expand Down
96 changes: 36 additions & 60 deletions src/schedule/bound.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

/*!
* Copyright (c) 2016 by Contributors
* \file bound.cc
Expand Down Expand Up @@ -277,10 +278,12 @@ void BoundProp(const Operation& op,
}
}


// Given the bound of output of op
// Pass the bound to the related axis in op.
void GatherOpBound(const ScanOpNode* scan,
const Operation& op,
const FeedGraph& fg,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
CHECK(!rmap->count(scan->scan_axis));
Expand All @@ -299,21 +302,29 @@ void GatherOpBound(const ScanOpNode* scan,
Range r = arith::Union(time_dom).cover_range(sdom);
(*rmap)[scan->scan_axis] = Range::make_with_min_extent(
sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
Array<Operation> body = ScanGetBody_(scan, fg);
Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(op, body);
// Update for spatial axis.
size_t sp_idx = 0;
for (size_t i = 0; i < output.size(); ++i) {
const TensorDom& d = tmap.at(output[i]);
for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
IterVar sp_ax = scan->spatial_axis_[sp_idx];
CHECK(!rmap->count(sp_ax));
// By default, we always need all spatial axes
// Unless that axis only refers back to itself as a fixed point.
// TODO(tqchen): Add fix point detection.
(*rmap)[sp_ax] = sp_ax->dom;
CHECK(fix_pt.count(sp_ax));
if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
// fix point, we can slice it.
(*rmap)[sp_ax] = arith::Union(d.data[k + 1]).cover_range(sp_ax->dom);
} else {
// not a fix point, need to include everything.
(*rmap)[sp_ax] = sp_ax->dom;
}
}
}
}

void GatherOpBound(const Operation& op,
const FeedGraph& fg,
const std::unordered_map<Tensor, TensorDom>& tmap,
std::unordered_map<IterVar, Range>* rmap) {
if (op.as<ComputeOpNode>()) {
Expand All @@ -329,7 +340,7 @@ void GatherOpBound(const Operation& op,
(*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
}
} else if (op.as<ScanOpNode>()) {
GatherOpBound(op.as<ScanOpNode>(), op, tmap, rmap);
GatherOpBound(op.as<ScanOpNode>(), op, fg, tmap, rmap);
} else if (op.as<PlaceholderOpNode>()) {
// do nothing
} else {
Expand All @@ -347,31 +358,26 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank;
}

// The map between a tensor and the operations it feeds
using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;

// AttachPath maps op-> a list of IterVar
// That represents the loop nest op sits in from inner most to outermost
using AttachPath = Map<Operation, Array<IterVar> >;


void InferRootBound(const Stage& stage,
const FeedGraph& feed_graph,
const AttachPath& attach_path,
std::unordered_map<IterVar, Range>* rmap) {
if (stage->attach_type == kInline) return;
if (stage->attach_type == kRoot || stage->attach_type == kNone) {
if (stage->is_output ||
stage->attach_type == kRoot ||
stage->attach_type == kNone) {
for (auto iv : OutputRelatedIterVars(stage->op)) {
CHECK(iv->dom.defined());
CHECK(!rmap->count(iv));
(*rmap)[iv] = iv->dom;
}
return;
}
// Infer root bounds for the attached node.
CHECK_EQ(stage->attach_type, kScope);
Stage parent = stage->attach_stage;
CHECK(parent.defined());
// parent stage, if any
Stage parent;
if (stage->attach_type == kScope || stage->attach_type == kScanUpdate) {
parent = stage->attach_stage;
}

// The tensor domain.
std::unordered_map<Tensor, TensorDom> tmap;
Expand All @@ -385,7 +391,7 @@ void InferRootBound(const Stage& stage,
auto it = feed_graph.find(t);
if (it != feed_graph.end()) {
for (const Operation& op : it->second) {
if (op != parent->op) {
if (!parent.defined() || op != parent->op) {
consumers.insert(op);
} else {
direct_consume_by_parent = true;
Expand All @@ -406,14 +412,17 @@ void InferRootBound(const Stage& stage,
}

if (direct_consume_by_parent) {
// parent stage if exist
Stage parent = stage->attach_stage;
// Bound inference logics in parent.
std::unordered_map<IterVar, IntSet> up_state;
bool fix_value = true;
for (auto iv : parent->leaf_iter_vars) {
Range vrange = rmap->at(iv);
CHECK(is_zero(vrange->min))
<< "InferBound requires every leaf iter var's min equals 0, "
<< "call schedule.normalize to achieve this.";
<< " call schedule.normalize to achieve this. "
<< " stage=" << parent;
// special optimization to remove trivial loop
if (is_one(vrange->extent)) {
up_state[iv] = IntSet::single_point(vrange->min);
Expand Down Expand Up @@ -464,8 +473,10 @@ void InferRootBound(const Stage& stage,
for (const Operation& op : consumers) {
std::unordered_map<const Variable*, IntSet> dom_map;
bool found = false;
Array<IterVar> attach = attach_path.at(stage->op);

for (IterVar iv : attach_path.at(op)) {
if (iv == stage->attach_ivar) {
if (attach.size() != 0 && iv == attach[0]) {
found = true; break;
}
Range vrange = rmap->at(iv);
Expand All @@ -474,7 +485,7 @@ void InferRootBound(const Stage& stage,
<< "call schedule.normalize to achieve this.";
relax_set[iv->var.get()] = IntSet::range(vrange);
}
CHECK(found)
CHECK(found || attach.size() == 0)
<< "Invalid Schedule, cannot find the producer " << stage->op
<< " along the loop nest specified by compute_at of consumer " << op;
for (auto iv : OutputRelatedIterVars(op)) {
Expand All @@ -483,50 +494,15 @@ void InferRootBound(const Stage& stage,
}
BoundProp(op, dom_map, &tmap);
}
GatherOpBound(stage->op, tmap, rmap);
GatherOpBound(stage->op, feed_graph, tmap, rmap);
}

FeedGraph CreateFeedGraph(const Schedule& sch) {
Map<IterVar, Range> InferBound(const Schedule& sch) {
Array<Operation> roots;
for (Operation op : sch->outputs) {
roots.push_back(sch->stage_map[op]->op);
}
auto g = CreateReadGraph(roots);
FeedGraph fg;
for (auto kv : g) {
for (Tensor t : kv.second) {
fg[t].push_back(kv.first);
}
}
return fg;
}

// Build the AttachPath for a schedule: for every stage, record the list of
// IterVars forming the loop nest its op is attached under, ordered from the
// innermost loop to the outermost one.
//
// For each stage the chain of kScope attachments is followed upward. At each
// hop, the parent's leaf iter vars are scanned from last to first, and every
// var from the attach point onward (in that reversed order) is appended.
// A CHECK fails if the attach point is not among the parent's leaf iter vars.
AttachPath CreateAttachPath(const Schedule& sch) {
  AttachPath result;
  for (Stage stage : sch->stages) {
    Array<IterVar> loop_nest;
    Stage cur = stage;
    // Climb the attachment chain while the current stage is scope-attached.
    while (cur->attach_type == kScope) {
      IterVar target = cur->attach_ivar;
      cur = cur->attach_stage;
      bool reached = false;
      // Walk the parent's leaf iter vars in reverse order.
      size_t idx = cur->leaf_iter_vars.size();
      while (idx != 0) {
        --idx;
        IterVar iv = cur->leaf_iter_vars[idx];
        if (!reached && iv == target) {
          reached = true;
        }
        if (reached) {
          loop_nest.push_back(iv);
        }
      }
      CHECK(reached)
          << "Invalid Schedule: cannot find attach point " << target
          << " in the schedule of " << cur->op;
    }
    result.Set(stage->op, loop_nest);
  }
  return result;
}

Map<IterVar, Range> InferBound(const Schedule& sch) {
FeedGraph feed_graph = CreateFeedGraph(sch);
FeedGraph feed_graph = CreateFeedGraph(CreateReadGraph(roots));
AttachPath attach_path = CreateAttachPath(sch);

std::unordered_map<IterVar, Range> ret;
Expand Down
Loading

0 comments on commit 98e830b

Please sign in to comment.