Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Meta schedule] improve search space #1

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/search_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ class SearchStrategy : public runtime::ObjectRef {
* \param num_trials_total The total number of trials for evolutionary search.
* \param population_size The initial sample population.
* \param init_measured_ratio The ratio of measured samples in the initial population.
* \param init_max_fail_count The maximum number to fail trace replaying.
* \param init_min_unmeasured The minimal size of unmeasured population in the initial sampling.
* \param genetic_num_iters The iterations to run the genetic algorithm.
* \param genetic_mutate_prob The probability of mutation.
* \param genetic_max_fail_count The maximum number to try evolving the given trace.
Expand All @@ -277,7 +277,7 @@ class SearchStrategy : public runtime::ObjectRef {
int num_trials_total, //
int population_size, //
double init_measured_ratio, //
int init_max_fail_count, //
int init_min_unmeasured, //
int genetic_num_iters, //
double genetic_mutate_prob, //
int genetic_max_fail_count, //
Expand Down
14 changes: 7 additions & 7 deletions python/tvm/meta_schedule/search_strategy/evolutionary_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class EvolutionarySearch(SearchStrategy):
The initial population of traces from measured samples and randomly generated samples.
init_measured_ratio : float
The ratio of measured samples in the initial population.
init_max_fail_count : int
The maximum number to fail trace replaying.
init_min_unmeasured : int
The minimal size of unmeasured population in the initial sampling.
genetic_num_iters : int
The number of iterations for genetic algorithm.
genetic_mutate_prob : float
Expand All @@ -56,7 +56,7 @@ class EvolutionarySearch(SearchStrategy):
num_trials_total: int
population_size: int
init_measured_ratio: int
init_max_fail_count: int
init_min_unmeasured: int
genetic_num_iters: int
genetic_mutate_prob: float
genetic_max_fail_count: int
Expand All @@ -69,7 +69,7 @@ def __init__(
num_trials_total: int,
population_size: int,
init_measured_ratio: float,
init_max_fail_count: int,
init_min_unmeasured: int,
genetic_num_iters: int,
genetic_mutate_prob: float,
genetic_max_fail_count: int,
Expand All @@ -82,7 +82,7 @@ def __init__(
num_trials_total,
population_size,
init_measured_ratio,
init_max_fail_count,
init_min_unmeasured,
genetic_num_iters,
genetic_mutate_prob,
genetic_max_fail_count,
Expand All @@ -97,7 +97,7 @@ class EvolutionarySearchConfig(NamedTuple):
num_trials_total: int
population_size: int = 2048
init_measured_ratio: float = 0.2
init_max_fail_count: int = 64
init_min_unmeasured: int = 50
genetic_num_iters: int = 4
genetic_mutate_prob: float = 0.85
genetic_max_fail_count: int = 10
Expand All @@ -109,7 +109,7 @@ def create_strategy(self) -> EvolutionarySearch:
num_trials_total=self.num_trials_total,
population_size=self.population_size,
init_measured_ratio=self.init_measured_ratio,
init_max_fail_count=self.init_max_fail_count,
init_min_unmeasured=self.init_min_unmeasured,
genetic_num_iters=self.genetic_num_iters,
genetic_mutate_prob=self.genetic_mutate_prob,
genetic_max_fail_count=self.genetic_max_fail_count,
Expand Down
15 changes: 14 additions & 1 deletion src/meta_schedule/measure_callback/update_cost_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,20 @@ class UpdateCostModelNode : public MeasureCallbackNode {
ICHECK(task->measure_candidates.defined()) //
<< "Task's measure candidates must be present!";
CostModel cost_model = task_scheduler->cost_model.value();
cost_model->Update(task, task->measure_candidates.value(), runner_results);
ICHECK_EQ(measure_candidates.size(), builder_results.size());
ICHECK_EQ(runner_results.size(), builder_results.size());
int n = builder_results.size();
Array<MeasureCandidate> pruned_candidate;
Array<RunnerResult> pruned_runner_result;
pruned_candidate.reserve(n);
pruned_runner_result.reserve(n);
for (int i = 0; i < n; i++) {
if (!builder_results[i]->error_msg.defined()) {
pruned_candidate.push_back(measure_candidates[i]);
pruned_runner_result.push_back(runner_results[i]);
}
}
cost_model->Update(task, pruned_candidate, pruned_runner_result);
}

static constexpr const char* _type_key = "meta_schedule.UpdateCostModel";
Expand Down
121 changes: 76 additions & 45 deletions src/meta_schedule/mutator/mutate_tile_size.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <mutex>
#include <unordered_map>

#include "../utils.h"

namespace tvm {
Expand Down Expand Up @@ -100,6 +103,49 @@ bool FindSamplePerfectTile(const Trace& trace, TRandState* rand_state, Instructi
return false;
}

/*!
 * \brief A process-wide, mutex-guarded memo of integer divisor lists.
 *
 * The only public entry point is the static method `Factorize`; the cache itself lives in a
 * function-local static singleton returned by `Global()`.
 */
struct FactorMemo {
  /*!
   * \brief Enumerate every positive divisor of `n`, in ascending order.
   * \param n The integer to factorize.
   * \return The sorted divisors of `n`, including 1 and `n` itself. For `n <= 0` the trial-division
   *  loop never runs and the result is empty.
   * \note Results are cached, so repeated calls with the same `n` return a copy of the cached list.
   */
  static std::vector<int> Factorize(int n) {
    // The pointer is dereferenced after Query() has released the lock. Entries are never erased
    // or overwritten, and std::unordered_map keeps pointers to its elements valid across
    // insertions, so the pointee cannot change underneath us.
    if (const std::vector<int>* cached = Global()->Query(n)) {
      return *cached;
    }
    std::vector<int> result;
    // Trial division up to and including sqrt(n). The bound must be inclusive: with `i * i < n`
    // the square root of a perfect square (e.g. 2 for n == 4) is never visited and is silently
    // dropped from the divisor list.
    for (int64_t i = 1; i * i <= n; ++i) {
      if (n % i == 0) {
        result.push_back(static_cast<int>(i));
        // Avoid recording sqrt(n) twice when n is a perfect square.
        if (i * i != n) {
          result.push_back(static_cast<int>(n / i));
        }
      }
    }
    std::sort(result.begin(), result.end());
    // Two threads may race to compute the same `n`; Add() uses emplace, so the first insertion
    // wins and the (identical) second result is simply discarded.
    Global()->Add(n, result);
    return result;
  }

 private:
  /*!
   * \brief Look up a cached divisor list.
   * \param n The integer whose divisors are requested.
   * \return A pointer to the cached list, or nullptr if `n` has not been factorized yet.
   */
  const std::vector<int>* Query(int n) {
    std::unique_lock<std::mutex> lock(mutex_);
    auto it = memo_.find(n);
    if (it != memo_.end()) {
      return &it->second;
    }
    return nullptr;
  }

  /*!
   * \brief Record the divisor list of `n`; a no-op if an entry for `n` already exists.
   * \param n The integer that was factorized.
   * \param result The divisors of `n`.
   */
  void Add(int n, std::vector<int> result) {
    std::unique_lock<std::mutex> lock(mutex_);
    memo_.emplace(n, std::move(result));
  }

  /*! \brief Access the singleton memo shared by all callers. */
  static FactorMemo* Global() {
    static FactorMemo singleton;
    return &singleton;
  }

  /*! \brief Cache mapping an integer to its sorted divisor list. */
  std::unordered_map<int, std::vector<int>> memo_;
  /*! \brief Guards every access to `memo_`. */
  std::mutex mutex_;
};

Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
Instruction inst;
std::vector<int64_t> tiles;
Expand All @@ -108,59 +154,44 @@ Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_s
}
int n_splits = tiles.size();
// Step 1. Choose two loops, `x` and `y`
int x = tir::SampleInt(rand_state, 0, n_splits);
int y;
if (tiles[x] == 1) {
// need to guarantee that tiles[x] * tiles[y] > 1
std::vector<int> idx;
idx.reserve(n_splits);
for (int i = 0; i < n_splits; ++i) {
if (tiles[i] > 1) {
idx.push_back(i);
}
int x, y;
// select source
while (true) {
x = tir::SampleInt(rand_state, 0, n_splits);
if (tiles[x] <= 1) {
continue;
}
y = idx[tir::SampleInt(rand_state, 0, idx.size())];
} else {
// sample without replacement
y = tir::SampleInt(rand_state, 0, n_splits - 1);
if (y >= x) {
++y;
}
}
// make sure x < y
CHECK_NE(x, y);
if (x > y) {
std::swap(x, y);
}
// Step 2. Choose the new tile size
int64_t len_x, len_y;
if (y != n_splits - 1) {
// Case 1. None of x and y are innermost loop
do {
std::vector<int64_t> result = tir::SamplePerfectTile(rand_state, tiles[x] * tiles[y], 2);
len_x = result[0];
len_y = result[1];
} while (len_y == tiles[y]);
} else {
// Case 2. y is the innermost loop
std::vector<int64_t> len_y_space;
int64_t limit = Downcast<Integer>(inst->attrs[1])->value;
int64_t prod = tiles[x] * tiles[y];
for (len_y = 1; len_y <= limit; ++len_y) {
if (len_y != tiles[y] && prod % len_y == 0) {
len_y_space.push_back(len_y);
std::vector<int> factors = FactorMemo::Factorize(tiles[x]);
// Step 2. Choose the divide factor
int64_t divide_factor;
if (y != n_splits - 1) {
divide_factor = factors[tir::SampleInt(rand_state, 1, factors.size())];
} else {
int64_t limit = Downcast<Integer>(inst->attrs[1])->value;
int max_factor_index = static_cast<int>(factors.size()) - 1;
for (; max_factor_index >= 1; max_factor_index--) {
if (factors[max_factor_index] * tiles[y] <= limit) {
break;
}
}
if (max_factor_index == 0) {
if (n_splits <= 2) {
return NullOpt;
}
// Failed on this dst_idx, try next one.
continue;
}
divide_factor = factors[tir::SampleInt(rand_state, 1, max_factor_index + 1)];
}
if (len_y_space.empty()) {
return NullOpt;
}
len_y = len_y_space[tir::SampleInt(rand_state, 0, len_y_space.size())];
len_x = prod / len_y;
tiles[x] /= divide_factor;
tiles[y] *= divide_factor;
return trace->WithDecision(inst, support::AsArray<int64_t, ObjectRef>(tiles),
/*remove_postproc=*/true);
}
tiles[x] = len_x;
tiles[y] = len_y;
return trace->WithDecision(inst, support::AsArray<int64_t, ObjectRef>(tiles),
/*remove_postproc=*/true);
}

/*! \brief Build a Mutator handle around a MutateTileSizeNode created via make_object. */
Mutator Mutator::MutateTileSize() {
  auto tile_size_node = make_object<MutateTileSizeNode>();
  return Mutator(tile_size_node);
}
Expand Down
46 changes: 23 additions & 23 deletions src/meta_schedule/search_strategy/evolutionary_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ class EvolutionarySearchNode : public SearchStrategyNode {
/*** Configuration: the initial population ***/
/*! \brief The ratio of measured states used in the initial population */
double init_measured_ratio;
/*! \brief The maximum number to fail trace replaying. */
int init_max_fail_count;
/*! \brief The minimal size of unmeasured population in the initial sampling.*/
int init_min_unmeasured;
/*** Configuration: evolution ***/
/*! \brief The number of iterations performed by genetic algorithm. */
int genetic_num_iters;
Expand Down Expand Up @@ -346,7 +346,7 @@ class EvolutionarySearchNode : public SearchStrategyNode {
v->Visit("population_size", &population_size);
/*** Configuration: the initial population ***/
v->Visit("init_measured_ratio", &init_measured_ratio);
v->Visit("init_max_fail_count", &init_max_fail_count);
v->Visit("init_min_unmeasured", &init_min_unmeasured);
/*** Configuration: evolution ***/
v->Visit("genetic_num_iters", &genetic_num_iters);
v->Visit("genetic_mutate_prob", &genetic_mutate_prob);
Expand Down Expand Up @@ -445,29 +445,30 @@ std::vector<Schedule> EvolutionarySearchNode::State::PickBestFromDatabase(int nu

std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int num) {
ThreadedTraceApply pp(self->postprocs_);
std::vector<Schedule> results(num, Schedule{nullptr});
auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void {
PerThreadData& data = self->per_thread_data_.at(thread_id);
TRandState* rand_state = &data.rand_state;
const IRModule& mod = data.mod;
Schedule& result = results.at(trace_id);
ICHECK(!result.defined());
for (int fail_count = 0; fail_count <= self->init_max_fail_count; ++fail_count) {
std::vector<Schedule> out_schs;
while (static_cast<int>(out_schs.size()) < self->init_min_unmeasured) {
std::vector<Schedule> results(num, Schedule{nullptr});
auto f_proc_unmeasured = [this, &results, &pp](int thread_id, int trace_id) -> void {
PerThreadData& data = self->per_thread_data_.at(thread_id);
TRandState* rand_state = &data.rand_state;
const IRModule& mod = data.mod;
Schedule& result = results.at(trace_id);
ICHECK(!result.defined());
int design_space_index = tir::SampleInt(rand_state, 0, design_spaces.size());
tir::Trace trace(design_spaces[design_space_index]->insts, {});
if (Optional<Schedule> sch = pp.Apply(mod, trace, rand_state)) {
result = sch.value();
break;
}
};
support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured);
for (int i = 0; i < num; i++) {
if (results[i].defined()) {
out_schs.push_back(results[i]);
}
}
if (!result.defined()) {
LOG(FATAL) << "Sample-Init-Population failed over the maximum limit! Summary:\n"
<< pp.SummarizeFailures();
}
};
support::parallel_for_dynamic(0, num, self->num_threads_, f_proc_unmeasured);
LOG(INFO) << "Sample-Init-Population summary:\n" << pp.SummarizeFailures();
return results;
LOG(INFO) << "Sample-Init-Population summary:\n" << pp.SummarizeFailures();
}
return out_schs;
}

std::vector<Schedule> EvolutionarySearchNode::State::EvolveWithCostModel(
Expand Down Expand Up @@ -627,7 +628,6 @@ Optional<Array<MeasureCandidate>> EvolutionarySearchNode::State::GenerateMeasure
LOG(INFO) << "Sampled " << unmeasured.size() << " candidate(s)";
inits.insert(inits.end(), measured.begin(), measured.end());
inits.insert(inits.end(), unmeasured.begin(), unmeasured.end());
ICHECK_EQ(inits.size(), self->population_size);
std::vector<Schedule> bests = EvolveWithCostModel(inits, sample_num);
LOG(INFO) << "Got " << bests.size() << " candidate(s) with evolutionary search";
std::vector<Schedule> picks = PickWithEpsGreedy(unmeasured, bests, sample_num);
Expand All @@ -646,7 +646,7 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, /
int num_trials_total, //
int population_size, //
double init_measured_ratio, //
int init_max_fail_count, //
int init_min_unmeasured, //
int genetic_num_iters, //
double genetic_mutate_prob, //
int genetic_max_fail_count, //
Expand All @@ -659,7 +659,7 @@ SearchStrategy SearchStrategy::EvolutionarySearch(int num_trials_per_iter, /
n->num_trials_total = num_trials_total;
n->population_size = population_size;
n->init_measured_ratio = init_measured_ratio;
n->init_max_fail_count = init_max_fail_count;
n->init_min_unmeasured = init_min_unmeasured;
n->genetic_num_iters = genetic_num_iters;
n->genetic_max_fail_count = genetic_max_fail_count;
n->genetic_mutate_prob = genetic_mutate_prob;
Expand Down
18 changes: 4 additions & 14 deletions src/tir/schedule/primitive/sampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,22 +300,12 @@ std::vector<int64_t> SamplePerfectTile(support::LinearCongruentialEngine::TRandS
return SamplePerfectTile(rand_state, extent, n_splits);
}
CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits";
std::vector<int32_t> innermost_candidates;
innermost_candidates.reserve(max_innermost_factor);
for (int32_t i = 1; i <= max_innermost_factor; ++i) {
if (extent % i == 0) {
innermost_candidates.push_back(i);
while (true) {
std::vector<int64_t> result = SamplePerfectTile(rand_state, extent, n_splits);
if (result.back() <= max_innermost_factor) {
return result;
}
}
// N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space.
// We should do multiple factorization to weight the choices. However, it would lead to slower
// sampling speed. On the other hand, considering potential tricks we might do on the innermost
// loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add
// more heuristics in the future
int32_t innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())];
std::vector<int64_t> result = SamplePerfectTile(rand_state, extent / innermost, n_splits - 1);
result.push_back(innermost);
return result;
}

std::vector<int64_t> SamplePerfectTile(
Expand Down
2 changes: 1 addition & 1 deletion tests/python/meta_schedule/test_meta_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def main():
config=ms.EvolutionarySearchConfig(
num_trials_per_iter=64,
num_trials_total=ARGS.num_trials,
init_max_fail_count=8192,
init_min_unmeasured=50
),
runner=runner,
task_name=ARGS.workload,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def predict(
num_trials_total=num_trials_total,
population_size=5,
init_measured_ratio=0.1,
init_max_fail_count=10,
init_min_unmeasured=50,
genetic_num_iters=3,
genetic_mutate_prob=0.5,
genetic_max_fail_count=10,
Expand Down