Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#16 from Superjomn/fea/add-scheduler
Browse files Browse the repository at this point in the history
make scheduler works
  • Loading branch information
Superjomn committed Feb 5, 2020
2 parents 187ecf4 + 417dfb0 commit fc24b37
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 30 deletions.
5 changes: 5 additions & 0 deletions cinn/common/graph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ class GraphNode : public Object {
static_assert(std::is_base_of<GraphNode, Derived>::value);
return static_cast<Derived*>(this);
}
template <typename Derived>
const Derived* As() const {
static_assert(std::is_base_of<GraphNode, Derived>::value);
return static_cast<const Derived*>(this);
}

//! Reset graph traversal meta info.
void ResetVisitMeta() { visited_time_ = 0; }
Expand Down
8 changes: 7 additions & 1 deletion cinn/poly/element.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@ void Element::InitSchedule() {
auto dims = GetDimNames(domain_);
auto dims_repr = utils::Join(dims, ", ");

auto repr = utils::StringFormat("{ %s[%s] -> [%s] }", id.c_str(), dims_repr.c_str(), dims_repr.c_str());
auto repr = utils::StringFormat("{ %s[%s] -> %s[%s] }", id.c_str(), dims_repr.c_str(), id.c_str(), dims_repr.c_str());
schedule_ = isl::map(domain_.ctx(), repr);

// set dimension names
for (int i = 0; i < dims.size(); i++) {
schedule_ = isl::manage(isl_map_set_dim_name(schedule_.release(), isl_dim_in, i, dims[i].c_str()));
schedule_ = isl::manage(isl_map_set_dim_name(schedule_.release(), isl_dim_out, i, dims[i].c_str()));
}
}

Element::Element(isl::set domain) : domain_(domain) {
Expand Down
85 changes: 75 additions & 10 deletions cinn/poly/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace poly {
struct ScheduleGraphNode : public common::GraphNode {
TimeSchedule time_schedule;

explicit ScheduleGraphNode(const std::vector<std::string> &dims) : time_schedule(dims) {}
explicit ScheduleGraphNode(const std::string &id, const std::vector<std::string> &dims) : time_schedule(id, dims) {}
};

struct ScheduleGraphEdge : public common::GraphEdge {
Expand All @@ -26,33 +26,73 @@ std::string TimeSchedule::__str__() const {
CHECK(!time_dims.empty());

// generate range: [dup, t0, t1...]
std::vector<std::string> range_dims({"dup"});
std::vector<std::string> range_dims;
for (int i = 0; i < time_dims.size(); i++) {
range_dims.push_back("t" + std::to_string(i));
range_dims.push_back("d" + std::to_string(i));
}

// generate conditions
std::vector<std::string> conds;
for (int i = 0; i < time_dims.size(); i++) {
conds.push_back(std::to_string(time_dims[i].time));
conds.push_back(time_dims[i].dim);
conds.push_back(utils::StringFormat("%s=%s", range_dims[2 * i].c_str(), std::to_string(time_dims[i].time).c_str()));
conds.push_back(utils::StringFormat("%s=%s", range_dims[2 * i + 1].c_str(), time_dims[i].dim.c_str()));
}

return utils::StringFormat("{ %s[%s] -> [%s]: %s",
id.c_str(),
return utils::StringFormat("{ %s[%s] -> [%s]: %s }",
id_.c_str(),
utils::Join(domain_dims, ", ").c_str(),
utils::Join(range_dims, ", ").c_str(),
utils::Join(conds, " and ").c_str());
}

TimeSchedule::TimeSchedule(const std::string &id, const std::vector<std::string> &dims) {
id_ = id;
domain_dims = dims;
for (auto &dim : domain_dims) {
time_dims.emplace_back(dim, 0);
}
}

void TimeSchedule::OrderAfter(const TimeSchedule &other, int level) {
CHECK_EQ(space_size(), other.space_size()) << "space not match";
CHECK_LT(level, other.space_size());
CHECK(!time_dims.empty());

for (int i = 0; i <= level; i++) {
this->time_dims[i].time = std::max(other.time_dims[i].time, this->time_dims[i].time);
}

this->time_dims[level].time++;
}

isl::map TimeSchedule::to_isl(isl::ctx ctx) const {
VLOG(3) << "isl: " << __str__();
return isl::map(ctx, __str__());
}

const std::string &TimeSchedule::id() const {
CHECK(!id_.empty());
return id_;
}

void Scheduler::RegisterElement(const Element &x) {
CHECK(!registration_finalized_) << "element registration has been finalized.";
space_size_ = std::max(space_size_, isl_map_dim(x.schedule().get(), isl_dim_out));
VLOG(3) << "space_size: " << space_size_;

// Use the dimensions from element's schedule's range as the new domain dimensions because in Element, the schedule is
// like '{ S0[i,j] -> S0[i_outer, i_inner, j] }', the scheduler should schedule base on the range.
TimeSchedule schedule(GetDimNames(x.schedule(), isl_dim_out));
schedule_graph_.RegisterNode(x.id(), common::make_shared<ScheduleGraphNode>(GetDimNames(x.schedule(), isl_dim_out)));
auto dims = GetDimNames(x.schedule(), isl_dim_out);
std::string id = isl_map_get_tuple_name(x.schedule().get(), isl_dim_in);
schedule_graph_.RegisterNode(x.id(),
common::make_shared<ScheduleGraphNode>(id, GetDimNames(x.schedule(), isl_dim_out)));

if (!ctx_.get()) {
ctx_ = x.domain().ctx();
} else {
CHECK_EQ(ctx_.get(), x.domain().ctx().get()) << "isl ctx not match";
}
}

void Scheduler::FinalizeRegistration() {
Expand All @@ -69,7 +109,7 @@ void Scheduler::FinalizeRegistration() {
Scheduler &Scheduler::After(const Element &a, const Element &b, int level) {
CHECK_LT(level, space_size_);
auto *a_node = schedule_graph_.RetriveNode(a.id())->As<ScheduleGraphNode>();
auto *b_node = schedule_graph_.RetriveNode(a.id())->As<ScheduleGraphNode>();
auto *b_node = schedule_graph_.RetriveNode(b.id())->As<ScheduleGraphNode>();
CHECK(a_node) << "no node called " << a.id() << " registered in the graph";
CHECK(b_node) << "no node called " << b.id() << " registered in the graph";

Expand All @@ -82,7 +122,32 @@ Scheduler &Scheduler::After(const Element &a, const Element &b, int level) {

Scheduler &Scheduler::Before(const Element &a, const Element &b, int level) { return After(b, a, level); }

std::unordered_map<std::string, isl::map> Scheduler::BuildSchedule() const {}
std::map<std::string, isl::map> Scheduler::BuildSchedule() const {
std::map<std::string, isl::map> res;
CHECK(ctx_.get());

ScheduleGraph::node_order_t node_order;
ScheduleGraph::edge_order_t edge_order;
std::tie(node_order, edge_order) = schedule_graph_.topological_order();
for (auto *edge : edge_order) {
auto *schedule_edge = edge->As<ScheduleGraphEdge>();
auto *a_node = schedule_graph_.RetriveNode(edge->source()->As<ScheduleGraphNode>()->time_schedule.id())
->As<ScheduleGraphNode>();
auto *b_node =
schedule_graph_.RetriveNode(edge->sink()->As<ScheduleGraphNode>()->time_schedule.id())->As<ScheduleGraphNode>();
CHECK(a_node);
CHECK(b_node);

int level = schedule_edge->level;
b_node->time_schedule.OrderAfter(a_node->time_schedule, level);
}

for (auto *node : schedule_graph_.nodes()) {
auto *schedule_node = node->As<ScheduleGraphNode>();
res[schedule_node->time_schedule.id()] = schedule_node->time_schedule.to_isl(ctx_);
}
return res;
}

} // namespace poly
} // namespace cinn
38 changes: 20 additions & 18 deletions cinn/poly/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,29 @@ struct ScheduleGraph : public common::Graph {};
* The range of the schedule.
*/
struct TimeSchedule {
//! ISL range format, such as '[dup, t0, t1]: dup=0 and t0=0 and t1=i]'
std::string __str__() const;

TimeSchedule(const std::vector<std::string> &dims) {
domain_dims = dims;
for (auto &dim : domain_dims) {
time_dims.emplace_back(dim, 0);
}
}
TimeSchedule(const std::string &id, const std::vector<std::string> &dims);

void ResizeTimeSpace(int size) { time_dims.resize(size); }

//! Schedule this after \p other in \p level.
void OrderAfter(const TimeSchedule &other, int level);

size_t space_size() const { return time_dims.size(); }

const std::string &id() const;

//! Get the isl map.
isl::map to_isl(isl::ctx ctx) const { return isl::map(ctx, __str__()); }
isl::map to_isl(isl::ctx ctx) const;

//! ISL range format, such as '[dup, t0, t1]: dup=0 and t0=0 and t1=i]'
std::string __str__() const;

std::string id;
std::vector<std::string> domain_dims;
int duplicate_id{};
std::vector<TimeDim> time_dims;

private:
std::string id_;
};

/**
Expand All @@ -73,7 +77,7 @@ class Scheduler {
* '{ S[i,j] -> [i_outer, i_inner, j]: i_outer=floor(i/4) and i_inner=i%4 }'
* that's OK.
*/
Scheduler() = default;
Scheduler() : ctx_(nullptr) {}

/**
* Register an Element to the scheduler.
Expand Down Expand Up @@ -102,7 +106,7 @@ class Scheduler {
/**
* Build and create schedule.
*/
std::unordered_map<std::string, isl::map> BuildSchedule() const;
std::map<std::string, isl::map> BuildSchedule() const;

private:
/**
Expand All @@ -114,12 +118,10 @@ class Scheduler {
int space_size_{};
//! Tell if the element registration is finalized.
bool registration_finalized_{false};
//! map from Schedule id to time schedule.
std::unordered_map<std::string, TimeSchedule> schedule_flows_;
//! Reversed dependency flow.
std::unordered_map<std::string, TimeSchedule> rev_schedule_flows_;

ScheduleGraph schedule_graph_;
mutable isl::ctx ctx_;

mutable ScheduleGraph schedule_graph_;
};

} // namespace poly
Expand Down
21 changes: 20 additions & 1 deletion cinn/poly/schedule_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,26 @@
namespace cinn {
namespace poly {

TEST(Schedule, basic) {}
TEST(Schedule, basic) {
isl::ctx ctx(isl_ctx_alloc());
isl::set A_set(ctx, "[]->{ A[i,j]: 0<i,j<100 }");
Element A(A_set);
isl::set B_set(ctx, "[]->{ B[i,j]: 0<i,j<100 }");
Element B(B_set);
LOG(INFO) << A.schedule();

Scheduler scheduler;
scheduler.RegisterElement(A);
scheduler.RegisterElement(B);

scheduler.After(A, B, 1);

auto schedule = scheduler.BuildSchedule();

for (auto item : schedule) {
LOG(INFO) << item.first << " " << item.second;
}
}

} // namespace poly
} // namespace cinn

0 comments on commit fc24b37

Please sign in to comment.