diff --git a/cinn/poly/element.cc b/cinn/poly/element.cc index 307acf2aa481c..0820bc76017cb 100644 --- a/cinn/poly/element.cc +++ b/cinn/poly/element.cc @@ -1,5 +1,6 @@ #include "cinn/poly/element.h" #include "cinn/poly/isl_utils.h" +#include "cinn/utils/functional.h" namespace cinn { namespace poly { @@ -38,7 +39,7 @@ Element::Element(isl::set domain) : domain_(domain) { std::tuple Element::Split(const Iterator &level, int factor) { int offset = isl_set_find_dim_by_name(domain_.get(), isl_dim_set, level.id.c_str()); CHECK_GE(offset, 0) << "iterator " << level << " not in " << domain_; - auto dim_names = GetDimNames(domain_); + auto dim_names = GetDimNames(schedule_, isl_dim_out); VLOG(2) << "domain: " << domain_; VLOG(2) << "schedule: " << schedule_; @@ -65,14 +66,18 @@ std::tuple Element::Split(const Iterator &level, int factor) } } - Map transform(domain_.ctx(), "", from_iters, to_iters, conds, ""); + Map transform(domain_.ctx(), id(), from_iters, to_iters, conds, id()); VLOG(3) << "transform: " << transform.__str__(); schedule_ = schedule_.apply_range(transform.to_isl()); + auto range_dims = + utils::Map, std::vector>(to_iters, [](const Iterator &x) { return x.id; }); + SetDimNames(&schedule_, isl_dim_out, range_dims); VLOG(3) << "transform " << transform.to_isl(); VLOG(3) << "schedule after transform: " << schedule_; - std::make_tuple(outer_iter, inner_iter); + VLOG(3) << "iterators: " << outer_iter << " " << inner_iter; + return std::make_tuple(outer_iter, inner_iter); } void Element::Reorder(const std::vector &order) {} @@ -107,5 +112,9 @@ std::string OuterName(const Iterator &iterator) { return OuterName(iterator.id); const char *Element::id() const { return isl_set_get_tuple_name(domain_.get()); } +std::tuple Element::Split(const std::string &level, int factor) { + return std::move(Split(Iterator(level), factor)); +} + } // namespace poly } // namespace cinn diff --git a/cinn/poly/element.h b/cinn/poly/element.h index 00c909dda1fdd..5d47ab458bd74 100644 --- a/cinn/poly/element.h +++ b/cinn/poly/element.h @@ -34,6 +34,8 @@ class Element { */ std::tuple // Split(const Iterator& level, int factor); + std::tuple // + Split(const std::string& level, int factor); /** * Reorder the iterators. diff --git a/cinn/poly/isl_utils.cc b/cinn/poly/isl_utils.cc index 3f16aa4028bcc..af943da12a421 100644 --- a/cinn/poly/isl_utils.cc +++ b/cinn/poly/isl_utils.cc @@ -1,4 +1,5 @@ #include "cinn/poly/isl_utils.h" +#include #include namespace cinn { @@ -20,5 +21,23 @@ std::vector GetDimNames(const isl::map &x, isl_dim_type dim_type) { return res; } +void SetDimNames(isl::map *map, isl_dim_type dim_type, const std::vector &names) { + const int dim = isl_map_dim(map->get(), dim_type); + CHECK_EQ(dim, names.size()); + + for (int i = 0; i < dim; i++) { + *map = isl::manage(isl_map_set_dim_name(map->release(), dim_type, i, names[i].c_str())); + } +} + +void SetDimNames(isl::set *set, const std::vector &names) { + int dim = isl_set_dim(set->get(), isl_dim_set); + CHECK_EQ(dim, names.size()); + + for (int i = 0; i < dim; i++) { + *set = isl::manage(isl_set_set_dim_name(set->release(), isl_dim_set, i, names[i].c_str())); + } +} + } // namespace poly } // namespace cinn diff --git a/cinn/poly/isl_utils.h b/cinn/poly/isl_utils.h index bd4a4b4fb9807..948db9b617192 100644 --- a/cinn/poly/isl_utils.h +++ b/cinn/poly/isl_utils.h @@ -9,9 +9,12 @@ namespace poly { //! Get dimension names from isl containers. // @{ -std::vector GetDimNames(const isl::set &x); -std::vector GetDimNames(const isl::map &x, isl_dim_type dim_type); +std::vector GetDimNames(const isl::set& x); +std::vector GetDimNames(const isl::map& x, isl_dim_type dim_type); // @} +void SetDimNames(isl::set* set, const std::vector& names); +void SetDimNames(isl::map* map, isl_dim_type dim_type, const std::vector& names); + } // namespace poly } // namespace cinn diff --git a/cinn/poly/map.h b/cinn/poly/map.h index c3e464bb801c0..0d557c6f7dceb 100644 --- a/cinn/poly/map.h +++ b/cinn/poly/map.h @@ -16,7 +16,9 @@ namespace poly { struct Iterator { std::string id; - explicit Iterator(std::string id) : id(std::move(id)) {} + explicit Iterator(const std::string& id) : id(id) {} + explicit Iterator(const Iterator& x) : id(x.id) {} + explicit Iterator(Iterator&& x) : id(std::move(x.id)) {} friend std::ostream& operator<<(std::ostream& os, const Iterator& x); }; @@ -25,7 +27,7 @@ struct Condition { Iterator iterator; std::string cond; - Condition(Iterator iterator, std::string cond) : iterator(std::move(iterator)), cond(std::move(cond)) {} + Condition(const Iterator& iterator, std::string cond) : iterator(iterator), cond(std::move(cond)) {} friend std::ostream& operator<<(std::ostream& os, const Condition& x) { os << x.__str__(); diff --git a/cinn/poly/schedule.cc b/cinn/poly/schedule.cc index ef430b91617c0..7a9aad31539ab 100644 --- a/cinn/poly/schedule.cc +++ b/cinn/poly/schedule.cc @@ -80,6 +80,7 @@ void Scheduler::RegisterElement(const Element &x) { CHECK(!registration_finalized_) << "element registration has been finalized."; space_size_ = std::max(space_size_, isl_map_dim(x.schedule().get(), isl_dim_out)); VLOG(3) << "space_size: " << space_size_; + VLOG(3) << "schedule: " << x.schedule(); // Use the dimensions from element's schedule's range as the new domain dimensions because in Element, the schedule is // like '{ S0[i,j] -> S0[i_outer, i_inner, j] }', the scheduler should schedule base on the range. diff --git a/cinn/poly/schedule_test.cc b/cinn/poly/schedule_test.cc index 410cb352be0ab..d76acd008f005 100644 --- a/cinn/poly/schedule_test.cc +++ b/cinn/poly/schedule_test.cc @@ -20,10 +20,37 @@ TEST(Schedule, basic) { auto schedule = scheduler.BuildSchedule(); + EXPECT_EQ(utils::GetStreamCnt(schedule["A"]), "{ A[i, j] -> [t0 = 0, d0 = i, t1 = 0, d1 = j] }"); + EXPECT_EQ(utils::GetStreamCnt(schedule["B"]), "{ B[i, j] -> [t0 = 0, d0 = i, t1 = 1, d1 = j] }"); + for (auto item : schedule) { LOG(INFO) << item.first << " " << item.second; } } +TEST(Schedule, basic_with_transform) { + isl::ctx ctx(isl_ctx_alloc()); + Element A(isl::set(ctx, "[]->{ A[i,j]: 0{ B[i,j]: 0 [t0 = 0, d0 = i_outer, t1 = 0, d1 = i_inner, t2 = 0, d2 = j] }"); + EXPECT_EQ(utils::GetStreamCnt(schedule["B"]), + "{ B[i, j_outer, j_inner] -> [t0 = 0, d0 = i, t1 = 1, d1 = j_outer, t2 = 0, d2 = j_inner] }"); +} + } // namespace poly } // namespace cinn diff --git a/cinn/utils/CMakeLists.txt b/cinn/utils/CMakeLists.txt index b1d0fa2c0d4f1..47fc6e48dd767 100644 --- a/cinn/utils/CMakeLists.txt +++ b/cinn/utils/CMakeLists.txt @@ -1,3 +1,4 @@ cc_library(utils SRCS string.cc target.cc + functional.cc ) diff --git a/cinn/utils/functional.cc b/cinn/utils/functional.cc new file mode 100644 index 0000000000000..8b137891791fe --- /dev/null +++ b/cinn/utils/functional.cc @@ -0,0 +1 @@ + diff --git a/cinn/utils/functional.h b/cinn/utils/functional.h new file mode 100644 index 0000000000000..aa1658badfe11 --- /dev/null +++ b/cinn/utils/functional.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +namespace cinn { +namespace utils { + +template +OutT Map(const InT& in, std::function fn) { + OutT res; + std::transform( + in.begin(), in.end(), std::back_inserter(res), [&](const typename InT::value_type& x) { return fn(x); }); + return res; +} + +} // namespace utils +} // namespace cinn