From 98e830b3b3571015c705a8520275186b30dee70c Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 15 Feb 2017 22:46:22 -0800 Subject: [PATCH] [SCAN] Enable fix point analysis for scan --- include/tvm/schedule.h | 3 +- include/tvm/tensor.h | 2 + python/tvm/build.py | 3 +- src/api/api_schedule.cc | 3 + src/schedule/bound.cc | 96 ++---- src/schedule/graph.cc | 311 +++++++++++++++++- src/schedule/graph.h | 54 +++ src/schedule/schedule_lang.cc | 9 +- src/schedule/schedule_ops.cc | 12 +- .../unittest/test_schedule_bound_inference.py | 17 - tests/python/unittest/test_schedule_graph.py | 101 ++++++ .../unittest/test_schedule_schedule_ops.py | 3 +- 12 files changed, 523 insertions(+), 91 deletions(-) create mode 100644 tests/python/unittest/test_schedule_graph.py diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 18407567744a3..2fdc3bf4978f1 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -26,7 +26,8 @@ enum AttachType : int { kNone = 0, kRoot = 1, kInline = 2, - kScope = 3 + kScope = 3, + kScanUpdate = 4 }; /*! \brief IterVar type */ diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index 92786b33106df..11766cd005d5f 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -175,6 +175,8 @@ class OperationNode : public FunctionBaseNode { virtual Type output_dtype(size_t i) const = 0; /*! \return shape of i-th output */ virtual Array output_shape(size_t i) const = 0; + + static constexpr const char* _type_key = "Operation"; }; // Implementations of inline functions diff --git a/python/tvm/build.py b/python/tvm/build.py index 40cb92b458aa7..764db0ae53049 100644 --- a/python/tvm/build.py +++ b/python/tvm/build.py @@ -63,7 +63,8 @@ def build(sch, arg_list.append(x) else: raise ValueError("args must be Tensor, Buffer or Var") - # lowering + # normalize schedule first + sch.normalize() bounds = schedule.InferBound(sch) stmt = schedule.ScheduleOps(sch, bounds) stmt = ir_pass.StorageFlatten(stmt, binds) diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 882ff94bde212..d953e37e23539 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -34,6 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise) REGISTER_SCHEDULE_PASS1(InferBound); REGISTER_SCHEDULE_PASS1(CreateReadGraph); REGISTER_SCHEDULE_PASS2(PostDFSOrder); +REGISTER_SCHEDULE_PASS1(ScanGetBody); +REGISTER_SCHEDULE_PASS1(CreateAttachPath); +REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis); REGISTER_SCHEDULE_PASS2(ScheduleOps); } // namespace schedule diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc index 4724d97627a74..cdf69d1acc4f2 100644 --- a/src/schedule/bound.cc +++ b/src/schedule/bound.cc @@ -1,3 +1,4 @@ + /*! * Copyright (c) 2016 by Contributors * \file bound.cc @@ -277,10 +278,12 @@ void BoundProp(const Operation& op, } } + // Given the bound of output of op // Pass the bound to the related axis in op. void GatherOpBound(const ScanOpNode* scan, const Operation& op, + const FeedGraph& fg, const std::unordered_map& tmap, std::unordered_map* rmap) { CHECK(!rmap->count(scan->scan_axis)); @@ -299,21 +302,29 @@ void GatherOpBound(const ScanOpNode* scan, Range r = arith::Union(time_dom).cover_range(sdom); (*rmap)[scan->scan_axis] = Range::make_with_min_extent( sdom->min, ir::Simplify(r->extent + r->min - sdom->min)); + Array body = ScanGetBody_(scan, fg); + Map fix_pt = ScanFixPointAnalysis(op, body); // Update for spatial axis. size_t sp_idx = 0; for (size_t i = 0; i < output.size(); ++i) { + const TensorDom& d = tmap.at(output[i]); for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { IterVar sp_ax = scan->spatial_axis_[sp_idx]; CHECK(!rmap->count(sp_ax)); - // In default, we always need all spatial axis - // Unless that axis only refers back to itself as a fixed point. - // TODO(tqchen): Add fix point detection. - (*rmap)[sp_ax] = sp_ax->dom; + CHECK(fix_pt.count(sp_ax)); + if (fix_pt[sp_ax].as()->value) { + // fix point, we can slice it. + (*rmap)[sp_ax] = arith::Union(d.data[k + 1]).cover_range(sp_ax->dom); + } else { + // not a fix point, need to include everything. + (*rmap)[sp_ax] = sp_ax->dom; + } } } } void GatherOpBound(const Operation& op, + const FeedGraph& fg, const std::unordered_map& tmap, std::unordered_map* rmap) { if (op.as()) { @@ -329,7 +340,7 @@ void GatherOpBound(const Operation& op, (*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom; } } else if (op.as()) { - GatherOpBound(op.as(), op, tmap, rmap); + GatherOpBound(op.as(), op, fg, tmap, rmap); } else if (op.as()) { // dp nothing } else { @@ -347,20 +358,14 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) { return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank; } -// The map beteen tensor and operation it feeds ti -using FeedGraph = std::unordered_map >; - -// AttachPath maps op-> a list of IterVar -// That represents the loop nest op sits in from inner most to outermost -using AttachPath = Map >; - - void InferRootBound(const Stage& stage, const FeedGraph& feed_graph, const AttachPath& attach_path, std::unordered_map* rmap) { if (stage->attach_type == kInline) return; - if (stage->attach_type == kRoot || stage->attach_type == kNone) { + if (stage->is_output || + stage->attach_type == kRoot || + stage->attach_type == kNone) { for (auto iv : OutputRelatedIterVars(stage->op)) { CHECK(iv->dom.defined()); CHECK(!rmap->count(iv)); @@ -368,10 +373,11 @@ void InferRootBound(const Stage& stage, } return; } - // Infer root bounds for the attached node. - CHECK_EQ(stage->attach_type, kScope); - Stage parent = stage->attach_stage; - CHECK(parent.defined()); + // parent stage, if any + Stage parent; + if (stage->attach_type == kScope || stage->attach_type == kScanUpdate) { + parent = stage->attach_stage; + } // The tensor domain. std::unordered_map tmap; @@ -385,7 +391,7 @@ void InferRootBound(const Stage& stage, auto it = feed_graph.find(t); if (it != feed_graph.end()) { for (const Operation& op : it->second) { - if (op != parent->op) { + if (!parent.defined() || op != parent->op) { consumers.insert(op); } else { direct_consume_by_parent = true; @@ -406,6 +412,8 @@ void InferRootBound(const Stage& stage, } if (direct_consume_by_parent) { + // parent stage if exist + Stage parent = stage->attach_stage; // Bound inference logics in parent. std::unordered_map up_state; bool fix_value = true; @@ -413,7 +421,8 @@ void InferRootBound(const Stage& stage, Range vrange = rmap->at(iv); CHECK(is_zero(vrange->min)) << "InferBound requires every leaf iter var's min equals 0, " - << "call schedule.normalize to achieve this."; + << " call schedule.normalize to achieve this. " + << " stage=" << parent; // special optimization to remove trivial loop if (is_one(vrange->extent)) { up_state[iv] = IntSet::single_point(vrange->min); @@ -464,8 +473,10 @@ void InferRootBound(const Stage& stage, for (const Operation& op : consumers) { std::unordered_map dom_map; bool found = false; + Array attach = attach_path.at(stage->op); + for (IterVar iv : attach_path.at(op)) { - if (iv == stage->attach_ivar) { + if (attach.size() != 0 && iv == attach[0]) { found = true; break; } Range vrange = rmap->at(iv); @@ -474,7 +485,7 @@ void InferRootBound(const Stage& stage, << "call schedule.normalize to achieve this."; relax_set[iv->var.get()] = IntSet::range(vrange); } - CHECK(found) + CHECK(found || attach.size() == 0) << "Invalid Schedule, cannot find the producer " << stage->op << " along the loop nest specified by compute_at of consumer " << op; for (auto iv : OutputRelatedIterVars(op)) { @@ -483,50 +494,15 @@ void InferRootBound(const Stage& stage, } BoundProp(op, dom_map, &tmap); } - GatherOpBound(stage->op, tmap, rmap); + GatherOpBound(stage->op, feed_graph, tmap, rmap); } -FeedGraph CreateFeedGraph(const Schedule& sch) { +Map InferBound(const Schedule& sch) { Array roots; for (Operation op : sch->outputs) { roots.push_back(sch->stage_map[op]->op); } - auto g = CreateReadGraph(roots); - FeedGraph fg; - for (auto kv : g) { - for (Tensor t : kv.second) { - fg[t].push_back(kv.first); - } - } - return fg; -} - -// Create AttachPath that maps op-> a list of IterVar -// That represents the loop nest op sits in from inner most to outermost -AttachPath CreateAttachPath(const Schedule& sch) { - AttachPath ret; - for (Stage stage : sch->stages) { - Array path; - for (Stage s = stage; s->attach_type == kScope;) { - IterVar attach_ivar = s->attach_ivar; - s = s->attach_stage; - bool start_attach = false; - for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { - IterVar iv = s->leaf_iter_vars[i - 1]; - if (iv == attach_ivar) start_attach = true; - if (start_attach) path.push_back(iv); - } - CHECK(start_attach) - << "Invalid Schedule: cannot find attach point " << attach_ivar - << " in the schedule of " << s->op; - } - ret.Set(stage->op, path); - } - return ret; -} - -Map InferBound(const Schedule& sch) { - FeedGraph feed_graph = CreateFeedGraph(sch); + FeedGraph feed_graph = CreateFeedGraph(CreateReadGraph(roots)); AttachPath attach_path = CreateAttachPath(sch); std::unordered_map ret; diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index f1047bf95ac90..66a6d1c8c98d6 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -8,6 +8,46 @@ #include #include "./graph.h" +namespace tvm { +namespace schedule { +// key to specific tensor dimension. +struct TensorDimKey { + FunctionRef f; + int value_index; + int dim; + TensorDimKey() {} + TensorDimKey(const ir::Call* op, int dim) + : f(op->func), value_index(op->value_index), dim(dim) { + } + TensorDimKey(const Tensor& t, int dim) + : f(t->op), value_index(t->value_index), dim(dim) { + } + inline bool operator==(const TensorDimKey& other) const { + return f == other.f && + value_index == other.value_index && + dim == other.dim; + } + inline bool operator!=(const TensorDimKey& other) const { + return !operator==(other); + } +}; +} // namespace schedule +} // namespace tvm + +namespace std { +template <> +struct hash<::tvm::schedule::TensorDimKey> { + std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const { + size_t lhs = k.f.hash(); + size_t rhs = static_cast(k.value_index) << 32UL | + static_cast(k.dim); + lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + return lhs; + } +}; +} // namespace std + + namespace tvm { namespace schedule { @@ -28,7 +68,7 @@ ReadGraph CreateReadGraph(const Array& roots) { stack.pop_back(); Array deps; if (op.as()) { - auto fvisit = [&deps, &visited, &stack](const NodeRef& n) { + auto fvisit = [&deps](const NodeRef& n) { auto *call = n.as(); if (call != nullptr && call->func.defined()) { Operation call_op(call->func.node_); @@ -59,7 +99,6 @@ ReadGraph CreateReadGraph(const Array& roots) { return rmap; } - void PostDFSOrder(const Operation& op, const ReadGraph& g, std::unordered_set* visited, @@ -83,5 +122,273 @@ Array PostDFSOrder( return post_order; } +FeedGraph CreateFeedGraph(const ReadGraph& g) { + FeedGraph fg; + for (auto kv : g) { + for (Tensor t : kv.second) { + fg[t].push_back(kv.first); + } + } + return fg; +} + +AttachPath CreateAttachPath(Schedule sch) { + AttachPath ret; + for (size_t i = sch->stages.size(); i != 0; --i) { + Stage stage = sch->stages[i - 1]; + // mark scan attach. + if (stage->op.as()) { + const ScanOpNode* scan = stage->op.as(); + for (Tensor t : scan->update) { + Stage ustage = sch->stage_map[t->op]; + CHECK(ustage->attach_type == kNone || + ustage->attach_type == kInline || + ustage->attach_type == kScanUpdate) + << "Cannot specify compute_at for scan's init/update"; + ustage->attach_type = kScanUpdate; + ustage->attach_ivar = stage->leaf_iter_vars[stage->leaf_iter_vars.size() - 1]; + ustage->attach_stage = stage; + } + } + Array path; + + for (Stage s = stage; s->attach_type == kScope || s->attach_type == kScanUpdate;) { + IterVar attach_ivar = s->attach_ivar; + s = s->attach_stage; + bool start_attach = false; + for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { + IterVar iv = s->leaf_iter_vars[i - 1]; + if (iv == attach_ivar) start_attach = true; + if (start_attach) path.push_back(iv); + } + CHECK(start_attach) + << "Invalid Schedule: cannot find attach point " << attach_ivar + << " in the schedule of " << s->op; + } + + if (!ret.count(stage->op)) { + ret.Set(stage->op, path); + } + } + return ret; +} + +// graph of push reach relation of tensor dimensions +using ReachGraph = std::unordered_map >; + +ReachGraph GetReachGraph(const Array& ops) { + ReachGraph reach; + std::unordered_set bset; + for (size_t i = 0; i < ops.size(); ++i) { + bset.insert(ops[i].get()); + } + + for (Operation op : ops) { + if (op.as()) { + const auto& update = op.as()->update; + const auto& init = op.as()->init; + for (size_t i = 0; i < update.size(); ++i) { + Tensor t = op.output(i); + for (size_t k = 0; k < update[i]->shape.size(); ++k) { + reach[TensorDimKey(t, k + 1)].emplace_back( + TensorDimKey(update[i], k)); + reach[TensorDimKey(t, k + 1)].emplace_back( + TensorDimKey(init[i], k + 1)); + } + } + } else if (op.as()) { + std::unordered_map vmap; + const auto& axis = op.as()->axis; + Tensor t = op.output(0); + for (size_t i = 0; i < axis.size(); ++i) { + vmap[axis[i]->var.get()] = TensorDimKey(t, i); + reach[TensorDimKey(t, i)] = {}; + } + auto fvisit = [&vmap, &reach, &bset](const NodeRef& n) { + const ir::Call *call = n.as(); + if (call != nullptr && call->func.defined()) { + if (!bset.count(call->func.get())) return; + for (size_t i = 0; i < call->args.size(); ++i) { + TensorDimKey dkey(call, i); + auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) { + const Variable *v = node.as(); + auto it = vmap.find(v); + if (it != vmap.end()) { + reach[it->second].push_back(dkey); + } + }; + ir::PostOrderVisit(call->args[i], fpush); + } + } + }; + ir::PostOrderVisit(op.as()->body, fvisit); + } + } + return reach; +} + +// Get all the operations that forms body of scan +void ScanGetBodyPostDFS_( + Operation op, + const ScanOpNode* scan, + const FeedGraph& feed_graph, + std::unordered_set* visited, + Array* result) { + if (op.get() == scan) return; + bool empty_feed = true; + for (int i = 0; i < op->num_outputs(); ++i) { + auto it = feed_graph.find(op.output(i)); + if (it != feed_graph.end() && it->second.size()) { + empty_feed = false; + for (const Operation& xop : it->second) { + if (visited->count(xop.get())) continue; + visited->insert(xop.get()); + ScanGetBodyPostDFS_(xop, scan, feed_graph, visited, result); + result->push_back(xop); + } + } + } + if (empty_feed && op.get() != scan) { + LOG(FATAL) << "Bad scan body, tensor reads scan_state but not connect to scan"; + } +} + +Array ScanGetBody_( + const ScanOpNode* scan, + const FeedGraph& feed_graph) { + CHECK(scan != nullptr); + std::unordered_set visited; + Array result; + for (Tensor t : scan->state_placeholder) { + ScanGetBodyPostDFS_(t->op, scan, feed_graph, &visited, &result); + } + return result; +} + +Array ScanGetBody(const Operation& scan) { + return ScanGetBody_(scan.as(), + CreateFeedGraph(CreateReadGraph({scan}))); +} + +Map ScanFixPointAnalysis( + const Operation& scan_op, const Array& body) { + const ScanOpNode* scan = scan_op.as(); + CHECK(body[0].get() == scan); + + std::unordered_map exact_reach; + std::unordered_set fail_set; + + for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) { + for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { + TensorDimKey key(scan->state_placeholder[i], k + 1); + exact_reach[key] = scan->spatial_axis_[sp_idx].get(); + } + } + // merge exact reach + auto f_merge_key = [&exact_reach, &fail_set]( + const TensorDimKey& dst, const TensorDimKey& src) { + auto sit = exact_reach.find(src); + if (sit == exact_reach.end()) return; + auto dit = exact_reach.find(dst); + if (dit == exact_reach.end()) { + exact_reach[dst] = sit->second; + } else { + if (dit->second != sit->second) { + fail_set.insert(dit->second); + fail_set.insert(sit->second); + } + } + }; + // prop exact reach back. + for (size_t i = body.size(); i != 1; --i) { + const Operation& op = body[i - 1]; + if (op.as()) { + const auto& update = op.as()->update; + const auto& init = op.as()->init; + for (size_t i = 0; i < update.size(); ++i) { + Tensor t = op.output(i); + for (size_t k = 0; i < update[i]->shape.size(); ++k) { + f_merge_key(TensorDimKey(t, k + 1), TensorDimKey(update[i], k)); + f_merge_key(TensorDimKey(t, k + 1), TensorDimKey(init[i], k + 1)); + } + } + } else if (op.as()) { + std::unordered_map vmap; + const auto& axis = op.as()->axis; + Tensor t = op.output(0); + for (size_t i = 0; i < axis.size(); ++i) { + vmap[axis[i]->var.get()] = TensorDimKey(t, i); + } + auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set]( + const NodeRef& n) { + const ir::Call *call = n.as(); + if (call != nullptr && call->func.defined()) { + for (size_t i = 0; i < call->args.size(); ++i) { + auto it = vmap.find(call->args[i].get()); + TensorDimKey src(call, i); + if (it != vmap.end()) { + f_merge_key(it->second, src); + } else { + if (exact_reach.count(src)) { + fail_set.insert(exact_reach.at(src)); + } + } + } + } + }; + ir::PostOrderVisit(op.as()->body, fvisit); + } + } + ReachGraph reach; + Map ret; + std::unordered_set place_holder_ref; + for (size_t i = 0; i < scan->state_placeholder.size(); ++i) { + for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) { + place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k)); + } + } + + for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) { + for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) { + TensorDimKey key(scan->update[i], k); + TensorDimKey target(scan->state_placeholder[i], k + 1); + IterVar sp_iv = scan->spatial_axis_[sp_idx]; + if (fail_set.count(sp_iv.get()) || + !exact_reach.count(key) || + exact_reach.at(key) != sp_iv.get()) { + ret.Set(sp_iv, make_const(Int(32), 0)); + } else { + // now we proved exact match, need to prove no interference with other graph. + if (reach.size() == 0) reach = GetReachGraph(body); + // do a DFS + std::unordered_set visited; + std::vector stack{key}; + visited.insert(key); + while (!stack.empty()) { + TensorDimKey k = stack.back(); + if (k != target && place_holder_ref.count(k)) break; + stack.pop_back(); + if (!reach.count(k)) { + LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim; + } + + for (TensorDimKey kk : reach.at(k)) { + if (visited.count(kk)) continue; + visited.insert(kk); + stack.push_back(kk); + } + } + if (!stack.empty()) { + // failed the prove. + ret.Set(sp_iv, make_const(Int(32), 0)); + } else { + ret.Set(sp_iv, make_const(Int(32), 1)); + } + } + } + } + return ret; +} + } // namespace schedule } // namespace tvm diff --git a/src/schedule/graph.h b/src/schedule/graph.h index 5a40c8e4ce0fb..4b4b2df6e747d 100644 --- a/src/schedule/graph.h +++ b/src/schedule/graph.h @@ -9,6 +9,7 @@ #include #include #include +#include #include namespace tvm { @@ -19,6 +20,16 @@ namespace schedule { */ using ReadGraph = Map >; +/*! + * \brief The map beteen tensor and operation it feeds to + */ +using FeedGraph = std::unordered_map >; + +/*! + * \brief AttachPath maps op-> a list of IterVar + */ +using AttachPath = Map >; + /*! * \brief Get read graph of each operation to all the * Tensors that it directly depends on. @@ -41,6 +52,49 @@ ReadGraph CreateReadGraph(const Array& roots); Array PostDFSOrder( const Array& roots, const ReadGraph& g); +/*! + * \brief Create feedgraph for given Schedule + * \param g The read graph. + * \return The created feedgraph. + */ +FeedGraph CreateFeedGraph(const ReadGraph& g); + +/*! + * \brief Create AttachPath that maps op-> a list of IterVar + * That represents the loop nest op sits in from inner most to outermost + * Also inserts attach_stage for scan updates when needed. + * + * \param sch The schedule. + * \return The attach path. + */ +AttachPath CreateAttachPath(Schedule sch); + +/*! + * \brief Get all operations inside the recursion of scan. + * \param scan The scan node. + * \param feed_graph The feed graph to help analysis. + * \return The body operations, in read dependency order. + */ +Array ScanGetBody_( + const ScanOpNode* scan, const FeedGraph& feed_graph); +// same as ScanGetBody_, but create FeedGraph internally. +Array ScanGetBody(const Operation& scan); + +/*! + * \brief Analyze each spatial dimension of scan's result. + * Give check on whether each dimension is fix point, + * An axis is a fixed point if it only refers back to itself in recursion + * and it is not used in axis of other recursion field. + * + * next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...] + * + * \param scan The scan node. + * \param body The body of scan, sorted in reverse PostDFSOrder. + * \return Map of spatial_axis -> IntImm + */ +Map ScanFixPointAnalysis( + const Operation& scan, const Array& body); + } // namespace schedule } // namespace tvm diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index b18ae28e54754..1df868be7a8b7 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -93,7 +93,9 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) { // NOLINT(*) } } CHECK(found) - << "Cannot find the axis in parent's leaf_iter_vars or outermost_threads"; + << "Cannot find the axis " << scope + << " in parent's leaf_iter_vars or outermost_threads:" + << " parent=" << parent; return *this; } @@ -278,18 +280,19 @@ void Schedule::normalize() { std::unordered_map rebase_map; std::unordered_map attach_mark; - for (Stage s : (*this)->stages) { if (s->attach_type == kScope) { attach_mark[s->attach_stage.get()] = 1; } + if (s->op.as()) { + attach_mark[s.get()] = 1; + } } for (Stage s : (*this)->stages) { if (!attach_mark.count(s.get())) continue; auto root_iter_vars = s->op->root_iter_vars(); ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite(); - for (IterVar iv : root_iter_vars) { size_t idx = FindNodeRef(leaf_vars, iv); if (idx < leaf_vars->data.size()) { diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index aa7c383635efa..8060f3785061e 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -614,7 +614,7 @@ class SchedulePostProc : public IRMutator { if (it->second.defined()) { Stmt ret = AttrStmt::make( it->second, op->type_key, op->value, op->body); - return this->Mutate_(ret.as(), ret); + return this->Mutate(ret); } else { return this->Mutate(op->body); } @@ -631,7 +631,7 @@ class SchedulePostProc : public IRMutator { Stmt ret = Realize::make( it->second->op, it->second->value_index, op->type, op->bounds, op->condition, op->body); - return this->Mutate_(ret.as(), ret); + return this->Mutate(ret); } else { return this->Mutate(op->body); } @@ -648,7 +648,7 @@ class SchedulePostProc : public IRMutator { Stmt ret = Provide::make( dst->op, dst->value_index, op->value, RewriteArgs(it->second.second, op->args)); - return IRMutator::Mutate_(ret.as(), ret); + return this->Mutate(ret); } else { return IRMutator::Mutate_(op, s); } @@ -664,7 +664,7 @@ class SchedulePostProc : public IRMutator { op->type, dst->op->name, RewriteArgs(it->second.second, op->args), op->call_type, dst->op, dst->value_index); - return IRMutator::Mutate_(ret.as(), ret); + return this->Mutate(ret); } } return IRMutator::Mutate_(op, e); @@ -758,7 +758,9 @@ Stmt ScheduleOps( // no need to specify place holder op. if (s->op.as()) continue; if (scan_attach.count(s->op)) { - CHECK(s->attach_type == kNone || s->attach_type == kInline) + CHECK(s->attach_type == kNone || + s->attach_type == kInline || + s->attach_type == kScanUpdate) << "Cannot specify compute_at for scan's init/update"; CHECK(body.defined()); const auto& p = scan_attach.at(s->op); diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index e80fb275c561c..0b1f6613e7b44 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -51,24 +51,7 @@ def test_bound3(): assert(bounds[A1.op.axis[1]].extent.value==16) -def test_create_read_graph(): - m = tvm.Var('m') - l = tvm.Var('l') - A = tvm.placeholder((m, l), name='A') - A1 = tvm.compute((m, l), lambda i, j: A[i, j]) - A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3) - - g = tvm.schedule.CreateReadGraph([A2.op]) - - assert g[A2.op][0] == A1 - assert g[A1.op][0] == A - post_order = tvm.schedule.PostDFSOrder([A2.op], g) - assert(post_order[0] == A.op) - assert(post_order[1] == A1.op) - - if __name__ == "__main__": - test_create_read_graph() test_bound3() test_bound1() test_bound2() diff --git a/tests/python/unittest/test_schedule_graph.py b/tests/python/unittest/test_schedule_graph.py new file mode 100644 index 0000000000000..d2536e140c3fc --- /dev/null +++ b/tests/python/unittest/test_schedule_graph.py @@ -0,0 +1,101 @@ +import tvm + +def test_scan(): + m = tvm.Var("m") + n = tvm.Var("n") + t = tvm.IterVar((1, m), name="t") + x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x") + s_state = tvm.placeholder((m, n)) + s_init = tvm.compute((1, n), lambda _, i: x[0, i], name="s_init") + x_trans = tvm.compute((m, n), lambda i, j: x[i, j] + 1, name="x_trans") + s_up1 = tvm.compute((n,), lambda i: s_state[t - 1, i] + 1, name="up1") + s_update = tvm.compute((n,), lambda i: s_up1[i] + x_trans[t, i], name="update") + s_scan = tvm.scan(t, s_init, s_update, s_state) + + def test_getbody(): + body = tvm.schedule.ScanGetBody(s_scan.op) + assert set(body) == set([s_scan.op, s_update.op, s_up1.op]) + + def test_attach_path(): + s = tvm.Schedule(s_scan.op) + s[x_trans].compute_at(s[s_update], s_update.op.axis[0]) + apath = tvm.schedule.CreateAttachPath(s) + assert(tuple(apath[s_update.op]) == tuple([s_scan.op.scan_axis])) + assert(tuple(apath[x_trans.op]) == tuple([s_update.op.axis[0], s_scan.op.scan_axis])) + + def test_fix_pt(): + body = tvm.schedule.ScanGetBody(s_scan.op) + fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) + assert(fxpt[s_scan.spatial_axis_[0]].value != 0) + +def test_scan_fix_point(): + m = tvm.Var("m") + n = tvm.Var("n") + l = tvm.Var("l") + t = tvm.IterVar((1, l), name="t") + x = tvm.compute((l, m, n), lambda *i: tvm.const(1, "float32"), name="x") + s_state = tvm.placeholder((l, m, n)) + s_init = tvm.compute((1, m, n), lambda _, i, j: x[0, i, j], name="s_init") + + def test_scan0(): + s_update = tvm.compute((m, n), lambda i, j: x[t, j, i] + s_state[t-1, i, j], name="update") + s_scan = tvm.scan(t, s_init, s_update, s_state) + body = tvm.schedule.ScanGetBody(s_scan.op) + fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) + assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1) + assert(fxpt[s_scan.op.spatial_axis_[1]].value == 1) + + def test_scan1(): + s_update = tvm.compute((m, n), lambda i, j: x[t, j, i] + s_state[t-1, j, i], name="update") + s_scan = tvm.scan(t, s_init, s_update, s_state) + body = tvm.schedule.ScanGetBody(s_scan.op) + fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) + assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0) + assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0) + + def test_scan3_not_exact_reach(): + s_h1 = tvm.compute((n, m), lambda j, i: s_state[t-1, i, j], name="h1") + s_h2 = tvm.compute((m, n), lambda i, j: s_state[t-1, i, 10] * 2, name="h1") + s_update = tvm.compute((m, n), lambda i, j: s_h1[j, i] + s_h2[i, j], name="update") + s_scan = tvm.scan(t, s_init, s_update, s_state) + body = tvm.schedule.ScanGetBody(s_scan.op) + fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) + assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1) + assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0) + + def test_scan4_reach_other(): + s_h1 = tvm.compute((n, m), lambda j, i: s_state[t-1, j, j], name="h1") + s_h2 = tvm.compute((m, n), lambda i, j: s_state[t-1, i, j] * 2, name="h1") + s_update = tvm.compute((m, n), lambda i, j: s_h1[j, i] + s_h2[i, j], name="update") + s_scan = tvm.scan(t, s_init, s_update, s_state) + body = tvm.schedule.ScanGetBody(s_scan.op) + fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body) + assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0) + assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0) + + test_scan0() + test_scan1() + test_scan3_not_exact_reach() + test_scan4_reach_other() + +def test_create_read_graph(): + m = tvm.Var('m') + l = tvm.Var('l') + A = tvm.placeholder((m, l), name='A') + A1 = tvm.compute((m, l), lambda i, j: A[i, j]) + A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3) + + g = tvm.schedule.CreateReadGraph([A2.op]) + + assert g[A2.op][0] == A1 + assert g[A1.op][0] == A + post_order = tvm.schedule.PostDFSOrder([A2.op], g) + assert(post_order[0] == A.op) + assert(post_order[1] == A1.op) + + +if __name__ == "__main__": + test_scan() + test_nest_scan_getbody() + test_create_read_graph() + test_scan_fix_point() diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 625bee5964141..fdcf4132a8f97 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -59,7 +59,6 @@ def test_schedule_scan(): stmt = tvm.schedule.ScheduleOps(s, bounds) print(stmt) - def test_auto_inline(): m = tvm.Var('m') n = tvm.Var('n') @@ -74,6 +73,7 @@ def test_auto_inline(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) + def test_schedule_cache(): m = tvm.Var('m') n = tvm.Var('n') @@ -90,7 +90,6 @@ def test_schedule_cache(): if __name__ == "__main__": - test_schedule_scan() test_schedule0() test_schedule1() test_schedule2()