diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h new file mode 100644 index 000000000000..9ee7039959e5 --- /dev/null +++ b/include/tvm/meta_schedule/measure_callback.h @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ +#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace meta_schedule { + +class TaskScheduler; + +/*! \brief Rules to apply after measure results is available. */ +class MeasureCallbackNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~MeasureCallbackNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief Apply a measure callback rule with given arguments. + * \param task_scheduler The task scheduler. + * \param tasks The list of tune context to process. + * \param measure_candidates The measure candidates. + * \param builds The builder results by building the measure candidates. + * \param results The runner results by running the built measure candidates. + * \return Whether the measure callback was successfully applied. 
+ */ + virtual bool Apply(const TaskScheduler& task_scheduler, // + const Array tasks, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results) = 0; + + static constexpr const char* _type_key = "meta_schedule.MeasureCallback"; + TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object); +}; + +/*! \brief The measure callback with customized methods on the python-side. */ +class PyMeasureCallbackNode : public MeasureCallbackNode { + public: + /*! + * \brief Apply a measure callback to the given schedule. + * \param task_scheduler The task scheduler. + * \param tasks The list of tune context to process. + * \param measure_candidates The measure candidates. + * \param builds The builder results by building the measure candidates. + * \param results The runner results by running the built measure candidates. + * \return Whether the measure callback was successfully applied. + */ + using FApply = + runtime::TypedPackedFunc tasks, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results)>; + /*! + * \brief Get the measure callback function as string with name. + * \return The string of the measure callback function. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `Apply` funcion. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` funcion. 
*/ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_apply` is not visited + // `f_as_string` is not visited + } + + bool Apply(const TaskScheduler& task_scheduler, // + const Array tasks, // + const Array& measure_candidates, // + const Array& builds, // + const Array& results) final { + ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!"; + return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results); + } + + static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode); +}; + +/*! + * \brief Managed reference to MeasureCallbackNode + * \sa MeasureCallbackNode + */ +class MeasureCallback : public runtime::ObjectRef { + public: + /*! + * \brief Create a measure callback with customized methods on the python-side. + * \param f_apply The packed function of `Apply`. + * \return The measure callback created. + */ + TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // + PyMeasureCallbackNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_ diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h new file mode 100644 index 000000000000..82f5b7683412 --- /dev/null +++ b/include/tvm/meta_schedule/mutator.h @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_MUTATOR_H_ +#define TVM_META_SCHEDULE_MUTATOR_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Mutator is designed to mutate the trace to explore the design space. */ +class MutatorNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~MutatorNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply the mutator function to the given trace. + * \param trace The given trace for mutation. + * \return None if mutator failed, otherwise return the mutated trace. + */ + virtual Optional Apply(const tir::Trace& trace) = 0; + + static constexpr const char* _type_key = "meta_schedule.Mutator"; + TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object); +}; + +/*! \brief The mutator with customized methods on the python-side. */ +class PyMutatorNode : public MutatorNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief Apply the mutator function to the given trace. + * \param trace The given trace for mutation. + * \return None if mutator failed, otherwise return the mutated trace. 
+ */ + using FApply = runtime::TypedPackedFunc(const tir::Trace&)>; + /*! + * \brief Get the mutator as string with name. + * \return The string of the mutator. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` funcion. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` funcion. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyMutator's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + Optional Apply(const tir::Trace& trace) final { + ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!"; + return this->f_apply(trace); + } + + static constexpr const char* _type_key = "meta_schedule.PyMutator"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode); +}; + +/*! + * \brief Managed reference to MutatorNode + * \sa MutatorNode + */ +class Mutator : public runtime::ObjectRef { + public: + /*! + * \brief Create a mutator with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \return The mutator created. 
+ */ + TVM_DLL static Mutator PyMutator( + PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyMutatorNode::FApply f_apply, // + PyMutatorNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_MUTATOR_H_ diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h new file mode 100644 index 000000000000..c24861d69784 --- /dev/null +++ b/include/tvm/meta_schedule/postproc.h @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_POSTPROC_H_ +#define TVM_META_SCHEDULE_POSTPROC_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! + * \brief Rules to apply a post processing to a schedule. + * \note Post processing is designed to deal with the problem of undertermined schedule validity + * after applying some schedule primitves at runtime. E.g., Fuse the first X loops to reach the + * maximum number below 1024, X is only decided at runtime. + */ +class PostprocNode : public runtime::Object { + public: + /*! \brief Virtual destructor. 
*/ + virtual ~PostprocNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply a post processing to the given schedule. + * \param sch The schedule to be post processed. + * \return Whether the post processing was successfully applied. + */ + virtual bool Apply(const tir::Schedule& schedule) = 0; + + static constexpr const char* _type_key = "meta_schedule.Postproc"; + TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object); +}; + +/*! \brief The post processing with customized methods on the python-side. */ +class PyPostprocNode : public PostprocNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief Apply a post processing to the given schedule. + * \param sch The schedule to be post processed. + * \return Whether the post processing was successfully applied. + */ + using FApply = runtime::TypedPackedFunc; + /*! + * \brief Get the post processing function as string with name. + * \return The string of the post processing function. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` funcion. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` funcion. 
*/ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyPostproc's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + bool Apply(const tir::Schedule& sch) final { + ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!"; + return this->f_apply(sch); + } + + static constexpr const char* _type_key = "meta_schedule.PyPostproc"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode); +}; + +/*! + * \brief Managed reference to PostprocNode + * \sa PostprocNode + */ +class Postproc : public runtime::ObjectRef { + public: + /*! + * \brief Create a post processing with customized methods on the python-side. + * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \return The post processing created. + */ + TVM_DLL static Postproc PyPostproc( + PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyPostprocNode::FApply f_apply, // + PyPostprocNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_POSTPROC_H_ diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h new file mode 100644 index 000000000000..92aa46beeaf6 --- /dev/null +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_H_ +#define TVM_META_SCHEDULE_SCHEDULE_RULE_H_ + +#include + +namespace tvm { +namespace meta_schedule { + +class TuneContext; + +/*! \brief Rules to modify a block in a schedule. */ +class ScheduleRuleNode : public runtime::Object { + public: + /*! \brief Virtual destructor. */ + virtual ~ScheduleRuleNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + /*! + * \brief The function type of `InitializeWithTuneContext` method. + * \param tune_context The tuning context for initialization. + */ + virtual void InitializeWithTuneContext(const TuneContext& context) = 0; + + /*! + * \brief Apply a schedule rule to the specific block in the given schedule. + * \param sch The schedule to be modified. + * \param block The specific block to apply the schedule rule. + * \return The list of schedules generated by applying the schedule rule. + */ + virtual runtime::Array Apply(const tir::Schedule& sch, + const tir::BlockRV& block) = 0; + + static constexpr const char* _type_key = "meta_schedule.ScheduleRule"; + TVM_DECLARE_BASE_OBJECT_INFO(ScheduleRuleNode, Object); +}; + +/*! \brief The schedule rule with customized methods on the python-side. */ +class PyScheduleRuleNode : public ScheduleRuleNode { + public: + /*! + * \brief The function type of `InitializeWithTuneContext` method. 
+ * \param tune_context The tuning context for initialization. + */ + using FInitializeWithTuneContext = runtime::TypedPackedFunc; + /*! + * \brief The function type of `Apply` method. + * \param sch The schedule to be modified. + * \param block The specific block to apply the schedule rule. + * \return The list of schedules generated by applying the schedule rule. + */ + using FApply = + runtime::TypedPackedFunc(const tir::Schedule&, const tir::BlockRV&)>; + /*! + * \brief Get the schedule rule as string with name. + * \return The string of the schedule rule. + */ + using FAsString = runtime::TypedPackedFunc; + + /*! \brief The packed function to the `InitializeWithTuneContext` funcion. */ + FInitializeWithTuneContext f_initialize_with_tune_context; + /*! \brief The packed function to the `Apply` funcion. */ + FApply f_apply; + /*! \brief The packed function to the `AsString` funcion. */ + FAsString f_as_string; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `f_initialize_with_tune_context` is not visited + // `f_apply` is not visited + // `f_as_string` is not visited + } + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(f_initialize_with_tune_context != nullptr) + << "PyScheduleRule's InitializeWithTuneContext method not implemented!"; + this->f_initialize_with_tune_context(context); + } + + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final { + ICHECK(f_apply != nullptr) << "PyScheduleRule's Apply method not implemented!"; + return this->f_apply(sch, block); + } + + static constexpr const char* _type_key = "meta_schedule.PyScheduleRule"; + TVM_DECLARE_FINAL_OBJECT_INFO(PyScheduleRuleNode, ScheduleRuleNode); +}; + +/*! + * \brief Managed reference to ScheduleRuleNode + * \sa ScheduleRuleNode + */ +class ScheduleRule : public runtime::ObjectRef { + public: + /*! + * \brief Create a schedule rule with customized methods on the python-side. 
+ * \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`. + * \param f_apply The packed function of `Apply`. + * \return The schedule rule created. + */ + TVM_DLL static ScheduleRule PyScheduleRule( + PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FAsString f_as_string); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode); +}; + +} // namespace meta_schedule +} // namespace tvm + +#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_H_ diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 0f3e9298d11a..3a0fa0ab4a64 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -21,6 +21,7 @@ #include #include +#include #include namespace tvm { @@ -247,6 +248,13 @@ class SearchStrategy : public runtime::ObjectRef { */ TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int num_trials_total); + /*! + * \brief Constructor of replay func search strategy. + * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. + * \param num_trials_total The total number of trials for func replaying. + */ + TVM_DLL static SearchStrategy ReplayFunc(int num_trials_per_iter, int num_trials_total); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchStrategy, ObjectRef, SearchStrategyNode); }; diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index eadf5e91506c..a0dfede82096 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -153,6 +153,14 @@ class SpaceGenerator : public ObjectRef { * \return The design space generator created. */ TVM_DLL static SpaceGenerator SpaceGeneratorUnion(Array space_generators); + /*! 
+ * \brief Create a design space generator that generates design spaces by applying schedule rules + * to blocks in post-DFS order. + * \param initialize_with_tune_context_func The packed function of `InitializeWithTuneContext`. + * \param generate_design_space_func The packed function of `GenerateDesignSpace`. + * \return The design space generator created. + */ + TVM_DLL static SpaceGenerator PostOrderApply(); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 64ba3ddeafb1..e4322554fe89 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -73,6 +74,8 @@ class TaskSchedulerNode : public runtime::Object { Runner runner{nullptr}; /*! \brief The database of the scheduler. */ Database database{nullptr}; + /*! \brief The list of measure callbacks of the scheduler. */ + Array measure_callbacks; /*! \brief The default desctructor. */ virtual ~TaskSchedulerNode() = default; @@ -82,6 +85,7 @@ class TaskSchedulerNode : public runtime::Object { v->Visit("builder", &builder); v->Visit("runner", &runner); v->Visit("database", &database); + v->Visit("measure_callbacks", &measure_callbacks); } /*! \brief Auto-tuning. 
*/ @@ -245,12 +249,14 @@ class TaskScheduler : public runtime::ObjectRef { TVM_DLL static TaskScheduler RoundRobin(Array tasks, // Builder builder, // Runner runner, // - Database database); // + Database database, // + Array measure_callbacks); TVM_DLL static TaskScheduler PyTaskScheduler( Array tasks, // Builder builder, // Runner runner, // Database database, // + Array measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index db72328c91c3..8ad6aa1d2194 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -20,6 +20,10 @@ #define TVM_META_SCHEDULE_TUNE_CONTEXT_H_ #include +#include +#include +#include +#include #include #include #include @@ -38,6 +42,12 @@ class TuneContextNode : public runtime::Object { Optional space_generator; /*! \brief The search strategy. */ Optional search_strategy; + /*! \brief The schedule rules. */ + Array sch_rules; + /*! \brief The post processings. */ + Array postprocs; + /*! \brief The mutators. */ + Array mutators; /*! \brief The name of the tuning task. */ Optional task_name; /*! \brief The random state. */ @@ -57,6 +67,9 @@ class TuneContextNode : public runtime::Object { v->Visit("target", &target); v->Visit("space_generator", &space_generator); v->Visit("search_strategy", &search_strategy); + v->Visit("sch_rules", &sch_rules); + v->Visit("postprocs", &postprocs); + v->Visit("mutators", &mutators); v->Visit("task_name", &task_name); v->Visit("rand_state", &rand_state); v->Visit("num_threads", &num_threads); @@ -81,6 +94,9 @@ class TuneContext : public runtime::ObjectRef { * \param target The target to be tuned for. * \param space_generator The design space generator. * \param search_strategy The search strategy. + * \param sch_rules The schedule rules. 
+ * \param postprocs The post processings. + * \param mutators The mutators. * \param task_name The name of the tuning task. * \param rand_state The random state. * \param num_threads The number of threads to be used. @@ -89,6 +105,9 @@ class TuneContext : public runtime::ObjectRef { Optional target, // Optional space_generator, // Optional search_strategy, // + Array sch_rules, // + Array postprocs, // + Array mutators, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads); diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index dfbbf2998509..fa4be5dee47d 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -155,6 +155,12 @@ class ScheduleNode : public runtime::Object { * \return The corresponding loop sref */ virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; + /*! + * \brief Check the existance of a specific BlockRV + * \param block_rv The BlockRV to be looked up + * \return Whether the corresponding block exists + */ + virtual bool HasBlock(const BlockRV& block_rv) const = 0; /*! * \brief Get the block/loop sref corresponding to the specific statement * \param stmt The statement to be looked up @@ -220,6 +226,18 @@ class ScheduleNode : public runtime::Object { * \return A list of loops above the given block in its scope, from outer to inner */ virtual Array GetLoops(const BlockRV& block_rv) = 0; + /*! + * \brief Get the leaf blocks of a specific scope + * \param block_rv The block where the scope is rooted + * \return A list of child blocks + */ + virtual Array GetChildBlocks(const BlockRV& block_rv) = 0; + /*! + * \brief Get the leaf blocks of under a specific loop + * \param loop_rv The loop under which collecting is conducted + * \return A list of child blocks + */ + virtual Array GetChildBlocks(const LoopRV& loop_rv) = 0; /******** Schedule: Transform loops ********/ /*! * \brief Fuse a list of consecutive loops into one. 
It requires: @@ -315,6 +333,11 @@ class ScheduleNode : public runtime::Object { */ virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) = 0; + /******** Schedule: Data movement ********/ + virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) = 0; + virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) = 0; /******** Schedule: Compute location ********/ /*! * \brief Move a producer block under the specific loop, and regenerate the diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 47b3dda5a36e..c57355a39134 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -19,6 +19,9 @@ from . import database from . import builder from . import runner +from . import mutator +from . import postproc +from . import schedule_rule from . import space_generator from . import search_strategy from . import integration diff --git a/python/tvm/meta_schedule/measure_callback/__init__.py b/python/tvm/meta_schedule/measure_callback/__init__.py new file mode 100644 index 000000000000..f455c1f4c7c3 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.measure_callback package. +""" +from .measure_callback import MeasureCallback, PyMeasureCallback diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py new file mode 100644 index 000000000000..f7daed55f684 --- /dev/null +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Meta Schedule MeasureCallback.""" + +from typing import TYPE_CHECKING, List + +from tvm._ffi import register_object +from tvm.runtime import Object + +from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..builder import BuilderResult +from ..runner import RunnerResult +from ..utils import _get_hex_address, check_override + +from .. import _ffi_api + +if TYPE_CHECKING: + from ..task_scheduler import TaskScheduler + + +@register_object("meta_schedule.MeasureCallback") +class MeasureCallback(Object): + """Rules to apply after measure results is available.""" + + def apply( + self, + task_scheduler: "TaskScheduler", + tasks: List["TuneContext"], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + """Apply a measure callback to the given schedule. + + Parameters + ---------- + task_scheduler: TaskScheduler + The task scheduler. + tasks: List[TuneContext] + The list of tune context to process. + measure_candidats: List[MeasureCandidate] + The measure candidates. + builds: List[BuilderResult] + The builder results by building the measure candidates. + results: List[RunnerResult] + The runner results by running the built measure candidates. + + Returns + ------- + result : bool + Whether the measure callback was successfully applied. 
+ """ + return _ffi_api.MeasureCallbackApply( + self, task_scheduler, tasks, measure_candidates, builds, results + ) + + +@register_object("meta_schedule.PyMeasureCallback") +class PyMeasureCallback(MeasureCallback): + """An abstract MeasureCallback with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, MeasureCallback) + def f_apply( + task_scheduler: "TaskScheduler", + tasks: List[TuneContext], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + return self.apply(task_scheduler, tasks, measure_candidates, builds, results) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.MeasureCallbackPyMeasureCallback, # type: ignore # pylint: disable=no-member + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyMeasureCallback({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/mutator/__init__.py b/python/tvm/meta_schedule/mutator/__init__.py new file mode 100644 index 000000000000..f88043b4b4fd --- /dev/null +++ b/python/tvm/meta_schedule/mutator/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.mutator package. +Meta Schedule mutator that mutates the trace to explore the +design space. +""" +from .mutator import Mutator, PyMutator diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py new file mode 100644 index 000000000000..f583154fec50 --- /dev/null +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule Mutator.""" +from typing import Optional, TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Trace + +from ..utils import _get_hex_address, check_override +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +class Mutator(Object): + """Mutator is designed to mutate the trace to explore the design space.""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the mutator with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the mutator. 
+ """ + _ffi_api.MutatorInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, trace: Trace) -> Optional[Trace]: + """Apply the mutator function to the given trace. + + Parameters + ---------- + trace : Trace + The given trace for mutation. + + Returns + ------- + trace : Optional[Trace] + None if mutator failed, otherwise return the mutated trace. + """ + return _ffi_api.MutatorApply(self, trace) + + +@register_object("meta_schedule.PyMutator") +class PyMutator(Mutator): + """An abstract mutator with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, Mutator) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, Mutator) + def f_apply(trace: Trace) -> Optional[Trace]: + return self.apply(trace) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.MutatorPyMutator, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyMutator({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py new file mode 100644 index 000000000000..5316eb466373 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.postproc package. +Meta Schedule post processings that deal with the problem of +undertermined schedule validity after applying some schedule +primitves at runtime. +""" +from .postproc import Postproc, PyPostproc diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py new file mode 100644 index 000000000000..06e0da8fd38a --- /dev/null +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Meta Schedule Postproc.""" + +from typing import TYPE_CHECKING + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule + +from .. 
import _ffi_api +from ..utils import _get_hex_address, check_override + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.Postproc") +class Postproc(Object): + """Rules to apply a post processing to a schedule. + + Note + ---- + Post processing is designed to deal with the problem of undertermined schedule validity after + applying some schedule primitves at runtime. E.g., Fuse the first X loops to reach the maximum + number below 1024, X is only decided at runtime. + """ + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the post processing with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the post processing. + """ + _ffi_api.PostprocInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, sch: Schedule) -> bool: + """Apply a post processing to the given schedule. + + Parameters + ---------- + sch : Schedule + The schedule to be post processed. + + Returns + ------- + result : bool + Whether the post processing was successfully applied. 
+ """ + return _ffi_api.PostprocApply(self, sch) + + +@register_object("meta_schedule.PyPostproc") +class PyPostproc(Postproc): + """An abstract Postproc with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, Postproc) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, Postproc) + def f_apply(sch: Schedule) -> bool: + return self.apply(sch) + + def f_as_string() -> str: + return str(self) + + self.__init_handle_by_constructor__( + _ffi_api.PostprocPyPostproc, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyPostproc({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py new file mode 100644 index 000000000000..34a7590b60c0 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -0,0 +1,19 @@ +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +The tvm.meta_schedule.schedule_rule package. +Meta Schedule schedule rules are used for modification of +blocks in a schedule. See also PostOrderApply. 
+""" +from .schedule_rule import ScheduleRule, PyScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py new file mode 100644 index 000000000000..ec101410f671 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Meta Schedule schedule rules are used for modification of +blocks in a schedule. See also PostOrderApply. +""" +from typing import TYPE_CHECKING, List + +from tvm._ffi import register_object +from tvm.runtime import Object +from tvm.tir.schedule import Schedule, BlockRV + +from ..utils import _get_hex_address, check_override +from .. import _ffi_api + +if TYPE_CHECKING: + from ..tune_context import TuneContext + + +@register_object("meta_schedule.ScheduleRule") +class ScheduleRule(Object): + """Rules to modify a block in a schedule.""" + + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + """Initialize the schedule rule with a tune context. + + Parameters + ---------- + tune_context : TuneContext + The tuning context for initializing the schedule rule. 
+ """ + _ffi_api.ScheduleRuleInitializeWithTuneContext( # type: ignore # pylint: disable=no-member + self, tune_context + ) + + def apply(self, schedule: Schedule, block: BlockRV) -> List[Schedule]: + """Apply a schedule rule to the specific block in the given schedule. + + Parameters + ---------- + sch : Schedule + The schedule to be modified. + block : BlockRV + The specific block to apply the schedule rule. + + Returns + ------- + design_spaces : List[Schedule] + The list of schedules generated by applying the schedule rule. + """ + return _ffi_api.ScheduleRuleApply(self, schedule, block) + + +@register_object("meta_schedule.PyScheduleRule") +class PyScheduleRule(ScheduleRule): + """An abstract schedule rule with customized methods on the python-side.""" + + def __init__(self): + """Constructor.""" + + @check_override(self.__class__, ScheduleRule) + def f_initialize_with_tune_context(tune_context: "TuneContext") -> None: + self.initialize_with_tune_context(tune_context) + + @check_override(self.__class__, ScheduleRule) + def f_apply(sch: Schedule, block: BlockRV) -> List[Schedule]: + return self.apply(sch, block) + + def f_as_string() -> str: + return self.__str__() + + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRulePyScheduleRule, # type: ignore # pylint: disable=no-member + f_initialize_with_tune_context, + f_apply, + f_as_string, + ) + + def __str__(self) -> str: + return f"PyScheduleRule({_get_hex_address(self.handle)})" diff --git a/python/tvm/meta_schedule/search_strategy/__init__.py b/python/tvm/meta_schedule/search_strategy/__init__.py index 609baa267786..e306c307bc49 100644 --- a/python/tvm/meta_schedule/search_strategy/__init__.py +++ b/python/tvm/meta_schedule/search_strategy/__init__.py @@ -20,5 +20,6 @@ to generate measure candidates. 
""" -from .search_strategy import SearchStrategy, PySearchStrategy +from .search_strategy import SearchStrategy, PySearchStrategy, MeasureCandidate from .replay_trace import ReplayTrace +from .replay_func import ReplayFunc diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py new file mode 100644 index 000000000000..8edd74ab02f6 --- /dev/null +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Replay Trace Search Strategy""" + +from tvm._ffi import register_object +from .search_strategy import SearchStrategy +from .. import _ffi_api + + +@register_object("meta_schedule.ReplayFunc") +class ReplayFunc(SearchStrategy): + """ + Replay Func Search Strategy is a search strategy that generates measure candidates by + calling a design space generator and transform the design space. + + Parameters + ---------- + num_trials_per_iter : int + Number of trials per iteration. + num_trials_total : int + Total number of trials. 
+ """ + + num_trials_per_iter: int + num_trials_total: int + + def __init__( + self, + num_trials_per_iter: int, + num_trials_total: int, + ): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SearchStrategyReplayFunc, # pylint: disable=no-member + num_trials_per_iter, + num_trials_total, + ) diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 15f8295f2524..3fd8bf7a44b6 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -41,7 +41,7 @@ class ReplayTrace(SearchStrategy): def __init__(self, num_trials_per_iter: int, num_trials_total: int): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.ReplayTrace, # type: ignore # pylint: disable=no-member + _ffi_api.SearchStrategyReplayTrace, # pylint: disable=no-member num_trials_per_iter, num_trials_total, ) diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index 6cee09edd4fc..e92bbbefcabb 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -22,7 +22,7 @@ from tvm._ffi import register_object from tvm.runtime import Object -from tvm.tir.schedule import Schedule +from tvm.tir.schedule import Schedule, Trace from .. import _ffi_api from ..arg_info import ArgInfo @@ -48,7 +48,11 @@ class MeasureCandidate(Object): sch: Schedule args_info: List[ArgInfo] - def __init__(self, sch: Schedule, args_info: List[ArgInfo]) -> None: + def __init__( + self, + sch: Schedule, + args_info: List[ArgInfo], + ) -> None: """Constructor. Parameters @@ -72,10 +76,7 @@ class SearchStrategy(Object): before usage and post-tuned after usage. 
""" - def initialize_with_tune_context( - self, - tune_context: "TuneContext", - ) -> None: + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: """Initialize the search strategy with tuning context. Parameters diff --git a/python/tvm/meta_schedule/space_generator/__init__.py b/python/tvm/meta_schedule/space_generator/__init__.py index af759d43b34a..fc08cd491de7 100644 --- a/python/tvm/meta_schedule/space_generator/__init__.py +++ b/python/tvm/meta_schedule/space_generator/__init__.py @@ -19,7 +19,7 @@ Meta Schedule design space generators that generates design space for generation of measure candidates. """ - from .space_generator import SpaceGenerator, PySpaceGenerator from .space_generator_union import SpaceGeneratorUnion from .schedule_fn import ScheduleFn +from .post_order_apply import PostOrderApply diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py new file mode 100644 index 000000000000..a9b2d560314a --- /dev/null +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Post Order Apply Space Generator.""" + + +from tvm._ffi import register_object +from .space_generator import SpaceGenerator +from .. import _ffi_api + + +@register_object("meta_schedule.PostOrderApply") +class PostOrderApply(SpaceGenerator): + """ + PostOrderApply is the design space generator that generates design spaces by applying schedule + rules to blocks in post-DFS order. + """ + + def __init__(self): + """Constructor""" + self.__init_handle_by_constructor__( + _ffi_api.SpaceGeneratorPostOrderApply, # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index e37fd14ba440..2172613ce1e6 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -36,10 +36,7 @@ class SpaceGenerator(Object): """The abstract design space generator interface.""" - def initialize_with_tune_context( - self, - tune_context: "TuneContext", - ) -> None: + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: """Initialize the design space generator with tuning context. Parameters diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 391011b4f53f..ab2d11ea7baa 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -19,6 +19,7 @@ from typing import List, TYPE_CHECKING from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from ..builder import Builder from ..runner import Runner @@ -41,6 +42,7 @@ def __init__( builder: Builder, runner: Runner, database: Database, + measure_callbacks: List[MeasureCallback] = [], ) -> None: """Constructor. @@ -54,6 +56,8 @@ def __init__( The runner. database : Database The database. 
+ measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. """ self.__init_handle_by_constructor__( _ffi_api.TaskSchedulerRoundRobin, # type: ignore # pylint: disable=no-member @@ -61,4 +65,5 @@ def __init__( builder, runner, database, + measure_callbacks, ) diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index aeea154cfe02..149e36662259 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -19,6 +19,7 @@ from typing import List from tvm._ffi import register_object +from tvm.meta_schedule.measure_callback.measure_callback import MeasureCallback from tvm.runtime import Object from ..runner import Runner @@ -43,12 +44,15 @@ class TaskScheduler(Object): The runner of the scheduler. database: Database The database of the scheduler. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. """ tasks: List[TuneContext] builder: Builder runner: Runner database: Database + measure_callbacks: List[MeasureCallback] def tune(self) -> None: """Auto-tuning.""" @@ -120,6 +124,7 @@ def __init__( builder: Builder, runner: Runner, database: Database, + measure_callbacks: List[MeasureCallback] = [], ): """Constructor. @@ -133,6 +138,8 @@ def __init__( The runner of the scheduler. database: Database The database of the scheduler. + measure_callbacks: List[MeasureCallback] + The list of measure callbacks of the scheduler. 
""" @check_override(self.__class__, TaskScheduler, required=False) @@ -173,6 +180,7 @@ def f_join_running_task(task_id: int) -> None: builder, runner, database, + measure_callbacks, f_tune, f_initialize_task, f_set_task_stopped, diff --git a/python/tvm/meta_schedule/testing/__init__.py b/python/tvm/meta_schedule/testing/__init__.py index 7e516a510f66..6a7b27b1f070 100644 --- a/python/tvm/meta_schedule/testing/__init__.py +++ b/python/tvm/meta_schedule/testing/__init__.py @@ -16,4 +16,5 @@ # under the License. """Testing utilities in meta schedule""" from .local_rpc import LocalRPC -from .relay_workload import get_network +from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model +from .te_workload import create_te_workload diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index 1eb9950f7fc7..ec6f65aff0aa 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -15,13 +15,178 @@ # specific language governing permissions and limitations # under the License. 
"""Workloads in Relay IR""" +from enum import Enum from typing import Dict, Tuple -import tvm.relay.testing # pylint: disable=unused-import from tvm import relay from tvm.ir import IRModule from tvm.runtime import NDArray +# Model types supported in Torchvision +class MODEL_TYPE(Enum): # pylint: disable=invalid-name + IMAGE_CLASSIFICATION = (1,) + VIDEO_CLASSIFICATION = (2,) + SEGMENTATION = (3,) + OBJECT_DETECTION = (4,) + + +# Specify the type of each model +MODEL_TYPES = { + # Image classification models + "resnet50": MODEL_TYPE.IMAGE_CLASSIFICATION, + "alexnet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "vgg16": MODEL_TYPE.IMAGE_CLASSIFICATION, + "squeezenet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet121": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet161": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet169": MODEL_TYPE.IMAGE_CLASSIFICATION, + "densenet201": MODEL_TYPE.IMAGE_CLASSIFICATION, + "inception_v3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "googlenet": MODEL_TYPE.IMAGE_CLASSIFICATION, + "shufflenet_v2_x1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_large": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mobilenet_v3_small": MODEL_TYPE.IMAGE_CLASSIFICATION, + "resnext50_32x4d": MODEL_TYPE.IMAGE_CLASSIFICATION, + "wide_resnet50_2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "mnasnet1_0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b0": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b1": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b2": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b3": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b4": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b5": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b6": MODEL_TYPE.IMAGE_CLASSIFICATION, + "efficientnet_b7": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_3_2gf": 
MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_y_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_400mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_800mf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_1_6gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_3_2gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_8gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_16gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + "regnet_x_32gf": MODEL_TYPE.IMAGE_CLASSIFICATION, + # Semantic Segmentation models + "fcn_resnet50": MODEL_TYPE.SEGMENTATION, + "fcn_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet50": MODEL_TYPE.SEGMENTATION, + "deeplabv3_resnet101": MODEL_TYPE.SEGMENTATION, + "deeplabv3_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + "lraspp_mobilenet_v3_large": MODEL_TYPE.SEGMENTATION, + # Object detection models + # @Sung: Following networks are not runnable since Torch frontend cannot handle aten::remainder. 
+ # "retinanet_resnet50_fpn", "keypointrcnn_resnet50_fpn", + "fasterrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_fpn": MODEL_TYPE.OBJECT_DETECTION, + "fasterrcnn_mobilenet_v3_large_320_fpn": MODEL_TYPE.OBJECT_DETECTION, + "retinanet_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "maskrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "keypointrcnn_resnet50_fpn": MODEL_TYPE.OBJECT_DETECTION, + "ssd300_vgg16": MODEL_TYPE.OBJECT_DETECTION, + "ssdlite320_mobilenet_v3_large": MODEL_TYPE.OBJECT_DETECTION, + # Video classification + "r3d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "mc3_18": MODEL_TYPE.VIDEO_CLASSIFICATION, + "r2plus1d_18": MODEL_TYPE.VIDEO_CLASSIFICATION, +} + + +def get_torch_model( + model_name: str, + input_shape: Tuple[int, ...], + output_shape: Tuple[int, int], # pylint: disable=unused-argument + dtype: str = "float32", +) -> Tuple[IRModule, Dict[str, NDArray]]: + """Load model from torch model zoo + Parameters + ---------- + model_name : str + The name of the model to load + input_shape: Tuple[int, ...] 
+ Tuple for input shape + output_shape: Tuple[int, int] + Tuple for output shape + dtype: str + Tensor data type + """ + + assert dtype == "float32" + + import torch # type: ignore # pylint: disable=import-error,import-outside-toplevel + from torchvision import models # type: ignore # pylint: disable=import-error,import-outside-toplevel + + def do_trace(model, inp): + model_trace = torch.jit.trace(model, inp) + model_trace.eval() + return model_trace + + # Load model from torchvision + if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + model = getattr(models, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + model = getattr(models.segmentation, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + model = getattr(models.detection, model_name)() + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + model = getattr(models.video, model_name)() + else: + raise ValueError("Unsupported model in Torch model zoo.") + + # Setup input + input_data = torch.randn(input_shape).type(torch.float32) + shape_list = [("input0", input_shape)] + + # Get trace. Depending on the model type, wrapper may be necessary. 
+ if MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + + class TraceWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return out["out"] + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + scripted_model = do_trace(wrapped_model, input_data) + + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + + def dict_to_tuple(out_dict): + if "masks" in out_dict.keys(): + return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"] + return out_dict["boxes"], out_dict["scores"], out_dict["labels"] + + class TraceWrapper(torch.nn.Module): # type: ignore + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return dict_to_tuple(out[0]) + + wrapped_model = TraceWrapper(model) + wrapped_model.eval() + with torch.no_grad(): + _ = wrapped_model(input_data) + scripted_model = do_trace(wrapped_model, input_data) + else: + scripted_model = do_trace(model, input_data) + + # Convert torch model to relay module + mod, params = relay.frontend.from_pytorch(scripted_model, shape_list) + return mod, params + def get_network( name: str, @@ -30,6 +195,8 @@ def get_network( dtype: str = "float32", ) -> Tuple[IRModule, Dict[str, NDArray], Tuple[int, int, int, int], Tuple[int, int]]: """Get the symbol definition and random weight of a network""" + import tvm.relay.testing # pylint: disable=import-outside-toplevel,unused-import + # meta-schedule prefers NHWC layout if layout == "NHWC": image_shape = (224, 224, 3) diff --git a/python/tvm/meta_schedule/testing/te_workload.py b/python/tvm/meta_schedule/testing/te_workload.py new file mode 100644 index 000000000000..e146750e259b --- /dev/null +++ b/python/tvm/meta_schedule/testing/te_workload.py @@ -0,0 +1,744 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license 
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Workloads in TE""" +from typing import Tuple +from tvm import te, tir, topi + + +def batch_matmul_nkkm( # pylint: disable=invalid-name,missing-docstring + B: int, + N: int, + M: int, + K: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + x = te.placeholder((B, N, K), name="X") + y = te.placeholder((B, K, M), name="Y") + k = te.reduce_axis((0, K), name="k") + z = te.compute( # pylint: disable=invalid-name + (B, N, M), + lambda b, i, j: te.sum(x[b][i][k] * y[b][k][j], axis=[k]), + name="Z", + ) + return (x, y, z) + + +def conv1d_nlc( # pylint: disable=invalid-name,missing-docstring + N: int, + L: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, L, CI), name="inputs") + weight = te.placeholder((kernel_size, CI // groups, CO), name="weight") + + batch_size, in_len, _ = inputs.shape + k_len, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + out_len = (in_len + 2 * padding - dilation * (k_len - 1) - 1) // stride + 1 + rc = te.reduce_axis((0, channel_per_group), name="rc") + rl = te.reduce_axis((0, k_len), name="rl") + + padded = topi.nn.pad(inputs, [0, padding, 0]) + output = 
te.compute( + (batch_size, out_len, out_channel), + lambda n, l, co: te.sum( + ( + padded[ + n, + l * stride + rl * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rl, rc, co] + ), + axis=[rl, rc], + ), + name="conv1d_nlc", + ) + return (inputs, weight, output) + + +def conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, CI), name="inputs") + weight = te.placeholder((kernel_size, kernel_size, CI // groups, CO), name="weight") + batch_size, in_h, in_w, _ = inputs.shape + k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, co: te.sum( + ( + padded[ + n, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rh, rw, rc, co] + ), + axis=[rh, rw, rc], + ), + name="conv2d_nhwc", + ) + return (inputs, weight, output) + + +def conv3d_ndhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + D: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, D, H, W, CI)) + weight = te.placeholder((kernel_size, kernel_size, kernel_size, CI // groups, CO)) + batch_size, in_d, in_h, in_w, _ = 
inputs.shape + k_d, k_h, k_w, channel_per_group, out_channel = weight.shape + out_channel_per_group = out_channel // groups + + out_d = (in_d + 2 * padding - dilation * (k_d - 1) - 1) // stride + 1 + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rd = te.reduce_axis((0, k_d), name="rd") + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + rc = te.reduce_axis((0, channel_per_group), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, padding, 0]) + output = te.compute( + (batch_size, out_d, out_h, out_w, out_channel), + lambda n, d, h, w, co: te.sum( + ( + padded[ + n, + d * stride + rd * dilation, + h * stride + rh * dilation, + w * stride + rw * dilation, + co // out_channel_per_group * channel_per_group + rc, + ] + * weight[rd, rh, rw, rc, co] + ), + axis=[rd, rh, rw, rc], + ), + name="conv3d_ndhwc", + ) + return (inputs, weight, output) + + +def depthwise_conv2d_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + C: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + factor: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, C)) + weight = te.placeholder((factor, kernel_size, kernel_size, C)) + batch_size, in_h, in_w, in_channel = inputs.shape + factor, k_h, k_w, in_channel = weight.shape + out_channel = in_channel * factor + assert int(factor) == 1, "Not optimized for factor != 1" + out_h = (in_h + 2 * padding - dilation * (k_h - 1) - 1) // stride + 1 + out_w = (in_w + 2 * padding - dilation * (k_w - 1) - 1) // stride + 1 + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + padded = topi.nn.pad(inputs, [0, padding, padding, 0]) + output = te.compute( + (batch_size, out_h, out_w, out_channel), + lambda n, h, w, c: te.sum( + ( + padded[ + n, + h * stride + rh * dilation, + w * 
stride + rw * dilation, + c // factor, + ] + * weight[c % factor, rh, rw, c // factor] + ), + axis=[rh, rw], + ), + name="depth_conv2d_nhwc", + ) + return (inputs, weight, output) + + +def conv2d_transpose_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, CI), name="inputs") + weight = te.placeholder((kernel_size, kernel_size, CI, CO), name="weight") + + batch, in_h, in_w, in_c = inputs.shape + filter_h, filter_w, in_c, out_c = weight.shape + stride_h, stride_w = (stride, stride) + + # compute padding + fpad_top, fpad_left, fpad_bottom, fpad_right = topi.nn.get_pad_tuple( + padding, (filter_h, filter_w) + ) + bpad_top = filter_h - 1 - fpad_top + bpad_bottom = filter_h - 1 - fpad_bottom + bpad_left = filter_w - 1 - fpad_left + bpad_right = filter_w - 1 - fpad_right + + # padding stage + padded = topi.nn.pad( + inputs, + [ + 0, + (bpad_top + stride_h - 1) // stride_h, + (bpad_left + stride_w - 1) // stride_w, + 0, + ], + [ + 0, + (bpad_bottom + stride_h - 1) // stride_h, + (bpad_right + stride_w - 1) // stride_w, + 0, + ], + ) + + # remove extra padding introduced by dilatation + idx_div = te.indexdiv + idx_mod = te.indexmod + border_h = idx_mod(stride_h - idx_mod(bpad_top, stride_h), stride_h) + border_w = idx_mod(stride_w - idx_mod(bpad_left, stride_w), stride_w) + + # dilation stage + strides = [1, stride_h, stride_w, 1] + n = len(padded.shape) + + # We should embed this dilation directly into te.compute rather than creating a new te.compute. + # Only in this way can we use unroll to eliminate the multiplication of zeros. 
+ def _dilate(*indices): + not_zero = [] + index_tuple = [] + for i in range(n): + if not strides[i] == 1: + index_tuple.append(idx_div(indices[i], strides[i])) + not_zero.append(idx_mod(indices[i], strides[i]).equal(0)) + else: + index_tuple.append(indices[i]) + if not_zero: + not_zero = te.all(*not_zero) + return te.if_then_else(not_zero, padded(*index_tuple), tir.const(0.0, padded.dtype)) + return padded(*index_tuple) + + # convolution stage + out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + rc = te.reduce_axis((0, in_c), name="rc") + rh = te.reduce_axis((0, filter_h), name="rh") + rw = te.reduce_axis((0, filter_w), name="rw") + + output = te.compute( + (batch, out_h, out_w, out_c), + lambda n, h, w, co: te.sum( + _dilate(n, h + rh + border_h, w + rw + border_w, rc) + * weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], + axis=[rh, rw, rc], + ), + name="conv2d_transpose_nhwc", + ) + return (inputs, weight, output) + + +def conv2d_capsule_nhwijc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + capsule_size: int = 4, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name="inputs") + weight = te.placeholder( + (kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name="weight" + ) + batch_size, in_h, in_w, _, _, in_channel = inputs.shape + k_h, k_w, _, _, _, out_channel = weight.shape + + out_h = (in_h + 2 * padding - kernel_size) // stride + 1 + out_w = (in_w + 2 * padding - kernel_size) // stride + 1 + + rh = te.reduce_axis((0, k_h), name="rh") + rw = te.reduce_axis((0, k_w), name="rw") + cap_k = te.reduce_axis((0, capsule_size), name="cap_k") + rc = te.reduce_axis((0, in_channel), name="rc") + + padded = topi.nn.pad(inputs, [0, padding, padding, 0, 0, 0]) + output = te.compute( + 
(batch_size, out_h, out_w, capsule_size, capsule_size, out_channel), + lambda n, h, w, cap_i, cap_j, co: te.sum( + ( + padded[n, h * stride + rh, w * stride + rw, cap_i, cap_k, rc] + * weight[rh, rw, cap_k, cap_j, rc, co] + ), + axis=[rh, rw, cap_k, rc], + ), + name="conv2d_capsule_nhwijc", + ) + return (inputs, weight, output) + + +def norm_bmn( # pylint: disable=invalid-name,missing-docstring + B: int, + M: int, + N: int, +) -> Tuple[te.Tensor, te.Tensor]: + a = te.placeholder((B, M, N), name="A") + i = te.reduce_axis((0, M), name="i") + j = te.reduce_axis((0, N), name="j") + c = te.compute( + (B,), + lambda b: te.sum(a[b][i][j] * a[b][i][j], axis=[i, j]), + name="C", + ) + d = te.compute((B,), lambda b: te.sqrt(c[b]), name="D") + return (a, d) + + +def conv2d_nhwc_without_layout_rewrite( # pylint: disable=invalid-name + Input: int, + Filter: int, + stride: int, + padding: int, + dilation: int, + out_dtype="float32", +): + """A copy of `topi.nn.conv2d_nhwc` but without the 'layout_free` attribute. + We use this in single op and subgraph evaluation + because we don't want to introduce graph level optimization. 
+ """ + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape # type: ignore + kernel_h, kernel_w, _channel, num_filter = Filter.shape # type: ignore + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = topi.nn.get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + out_channel = num_filter + out_height = topi.utils.simplify( + (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1 + ) + out_width = topi.utils.simplify( + (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1 + ) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = topi.nn.pad(Input, pad_before, pad_after, name="PaddedInput") + rc = te.reduce_axis((0, in_channel), name="rc") + ry = te.reduce_axis((0, kernel_h), name="ry") + rx = te.reduce_axis((0, kernel_w), name="rx") + Output = te.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: te.sum( + PaddedInput[ + nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rc + ].astype(out_dtype) + * Filter[ry, rx, rc, ff].astype(out_dtype), # type: ignore + axis=[ry, rx, rc], + ), + name="Conv2dOutput", + tag="conv2d_nhwc", + ) + return Output + + +def conv2d_nhwc_bn_relu( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + strides: int, + padding: int, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor, te.Tensor]: + data = te.placeholder((N, H, W, CI), name="data") + kernel = 
te.placeholder((kernel_size, kernel_size, CI, CO), name="kernel") + bias = te.placeholder((CO,), name="bias") + bn_scale = te.placeholder((CO,), name="bn_scale") + bn_offset = te.placeholder((CO,), name="bn_offset") + OH = (H + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + OW = (W + 2 * padding - (kernel_size - 1) * dilation - 1) // strides + 1 + conv = conv2d_nhwc_without_layout_rewrite(data, kernel, strides, padding, dilation) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bias[l], name="bias_add" + ) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] * bn_scale[l], name="bn_mul" + ) + conv = te.compute( + (N, OH, OW, CO), lambda i, j, k, l: conv[i, j, k, l] + bn_offset[l], name="bn_add" + ) + out = topi.nn.relu(conv) + return (data, kernel, bias, bn_offset, bn_scale, out) + + +def transpose_batch_matmul( # pylint: disable=invalid-name,missing-docstring + batch: int, + seq_len: int, + n_head: int, + n_dim: int, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + query = te.placeholder((batch, seq_len, n_head, n_dim), name="query") + value = te.placeholder((batch, seq_len, n_head, n_dim), name="value") + query_T = te.compute( + (batch, n_head, seq_len, n_dim), + lambda b, h, l, d: query[b, l, h, d], + name="query_T", + ) + value_T = te.compute( + (batch, n_head, n_dim, seq_len), + lambda b, h, d, l: value[b, l, h, d], + name="value_T", + ) + k = te.reduce_axis((0, n_dim), name="k") + out = te.compute( + (batch, n_head, seq_len, seq_len), + lambda b, h, i, j: te.sum(query_T[b, h, i, k] * value_T[b, h, k, j], axis=[k]), + name="C", + ) + return (query, value, out) + + +def conv2d_winograd_nhwc( # pylint: disable=invalid-name,missing-docstring + N: int, + H: int, + W: int, + CI: int, + CO: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, +) -> Tuple[te.Tensor, te.Tensor, te.Tensor]: + tile_size = 4 # _infer_tile_size(data, kernel) + inputs = te.placeholder((N, 
H, W, CI), name="inputs") + N, H, W, CI = topi.utils.get_const_tuple(inputs.shape) + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation" + + KH = KW = kernel_size + HPAD, WPAD, _, _ = topi.nn.get_pad_tuple(padding, (KH, KW)) + HSTR, WSTR = (stride, stride) if isinstance(stride, int) else stride + assert HSTR == 1 and WSTR == 1 and KH == KW + + data_pad = topi.nn.pad(inputs, (0, HPAD, WPAD, 0), (0, HPAD, WPAD, 0), name="data_pad") + + r = KW + m = tile_size + alpha = m + r - 1 + A, B, _G = topi.nn.winograd_util.winograd_transform_matrices(m, r, "float32") + + H = (H + 2 * HPAD - KH) // HSTR + 1 + W = (W + 2 * WPAD - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + P = N * nH * nW + _rkh = te.reduce_axis((0, KH), name="r_kh") + _rkw = te.reduce_axis((0, KW), name="r_kw") + kshape = (alpha, alpha, CI, CO) + kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight") + + idxdiv = te.indexdiv + idxmod = te.indexmod + # pack input tile + input_tile = te.compute( + (alpha, alpha, P, CI), + lambda eps, nu, p, ci: data_pad[idxdiv(p, (nH * nW))][idxmod(idxdiv(p, nW), nH) * m + eps][ + idxmod(p, nW) * m + nu + ][ci], + name="input_tile", + ) + + # transform data + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_b") + data_pack = te.compute( + (alpha, alpha, P, CI), + lambda eps, nu, p, ci: te.sum( + input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] + ), + name="data_pack", + attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]}, + ) + + # do batch gemm + ci = te.reduce_axis((0, CI), name="ci") + bgemm = te.compute( + (alpha, alpha, P, CO), + lambda eps, nu, p, co: te.sum( + data_pack[eps][nu][p][ci] * kernel_pack[eps][nu][ci][co], axis=[ci] + ), + name="bgemm", + ) + + # inverse transform + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = 
te.reduce_axis((0, alpha), "r_b") + inverse = te.compute( + (m, m, P, CO), + lambda vh, vw, p, co: te.sum( + bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] + ), + name="inverse", + attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]}, + ) + + # output + output = te.compute( + (N, H, W, CO), + lambda n, h, w, co: inverse[ + idxmod(h, m), idxmod(w, m), n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m), co + ], + name="conv2d_winograd", + ) + + return (inputs, kernel_pack, output) + + + +def create_te_workload(name: str, idx: int) -> tir.PrimFunc: + workload_func, params = CONFIGS[name] + return te.create_prim_func(workload_func(*params[idx])) # type: ignore + + +CONFIGS = { + "C1D": ( + conv1d_nlc, + [ + # derived from conv2d_shapes + (1, 256, 64, 128, 3, 2, 1), + # (1, 256, 64, 128, 1, 2, 0), + # (1, 256, 64, 64, 1, 1, 0), + # (1, 128, 128, 256, 3, 2, 1), + (1, 128, 128, 256, 1, 2, 0), + # (1, 128, 128, 128, 3, 1, 1), + # (1, 64, 256, 512, 3, 2, 1), + # (1, 64, 256, 512, 1, 2, 0), + (1, 64, 256, 256, 5, 1, 2), + (1, 32, 512, 512, 3, 1, 1), + ], + ), + "C2D": ( + conv2d_nhwc, + [ + # all conv2d layers in resnet-18 + (1, 224, 224, 3, 64, 7, 2, 3), + # (1, 56, 56, 64, 128, 3, 2, 1), + # (1, 56, 56, 64, 128, 1, 2, 0), + # (1, 56, 56, 64, 64, 3, 1, 1), + (1, 56, 56, 64, 64, 1, 1, 0), + # (1, 28, 28, 128, 256, 3, 2, 1), + # (1, 28, 28, 128, 256, 1, 2, 0), + # (1, 28, 28, 128, 128, 3, 1, 1), + # (1, 14, 14, 256, 512, 3, 2, 1), + # (1, 14, 14, 256, 512, 1, 2, 0), + (1, 14, 14, 256, 256, 3, 1, 1), + (1, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "C3D": ( + conv3d_ndhwc, + [ + # Derived from conv2d_shapes. 
Use depth=16 for all configurations + (1, 16, 224, 224, 3, 64, 7, 2, 3), + # (1, 16, 56, 56, 64, 128, 3, 2, 1), + # (1, 16, 56, 56, 64, 128, 1, 2, 0), + # (1, 16, 56, 56, 64, 64, 3, 1, 1), + (1, 16, 56, 56, 64, 64, 1, 1, 0), + # (1, 16, 28, 28, 128, 256, 3, 2, 1), + # (1, 16, 28, 28, 128, 256, 1, 2, 0), + # (1, 16, 28, 28, 128, 128, 3, 1, 1), + # (1, 16, 14, 14, 256, 512, 3, 2, 1), + # (1, 16, 14, 14, 256, 512, 1, 2, 0), + (1, 16, 14, 14, 256, 256, 3, 1, 1), + (1, 16, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "GMM": ( + batch_matmul_nkkm, + [ + (1, 128, 128, 128), + (1, 512, 32, 512), + (1, 512, 512, 512), + (1, 1024, 1024, 1024), + ], + ), + "GRP": ( + conv2d_nhwc, + [ + # Derived from conv2d_shapes. Use group=4 for all configurations + (1, 56, 56, 64, 128, 3, 2, 1, 1, 4), + # (1, 56, 56, 64, 128, 1, 2, 0 , 1, 4), + # (1, 56, 56, 64, 64, 3, 1, 1 , 1, 4), + (1, 56, 56, 64, 64, 1, 1, 0, 1, 4), + # (1, 28, 28, 128, 256, 3, 2, 1, 1, 4), + # (1, 28, 28, 128, 256, 1, 2, 0, 1, 4), + # (1, 28, 28, 128, 128, 3, 1, 1, 1, 4), + # (1, 14, 14, 256, 512, 3, 2, 1, 1, 4), + # (1, 14, 14, 256, 512, 1, 2, 0, 1, 4), + (1, 14, 14, 256, 256, 3, 1, 1, 1, 4), + (1, 7, 7, 512, 512, 3, 1, 1, 1, 4), + ], + ), + "DIL": ( + conv2d_nhwc, + [ + # Derived from conv2d_shapes. 
Use dilation=2 for all configurations + (1, 224, 224, 3, 64, 7, 2, 3, 2), + # (1, 56, 56, 64, 128, 3, 2, 1 , 2), + # (1, 56, 56, 64, 128, 1, 2, 0 , 2), + # (1, 56, 56, 64, 64, 3, 1, 1 , 2), + (1, 56, 56, 64, 64, 1, 1, 0, 2), + # (1, 28, 28, 128, 256, 3, 2, 1, 2), + # (1, 28, 28, 128, 256, 1, 2, 0, 2), + # (1, 28, 28, 128, 128, 3, 1, 1, 2), + # (1, 14, 14, 256, 512, 3, 2, 1, 2), + # (1, 14, 14, 256, 512, 1, 2, 0, 2), + (1, 14, 14, 256, 256, 3, 1, 1, 2), + (1, 7, 7, 512, 512, 3, 1, 1, 2), + ], + ), + "DEP": ( + depthwise_conv2d_nhwc, + [ + # all depthwise conv2d layers in mobilenet + (1, 112, 112, 32, 3, 1, 1), + (1, 112, 112, 64, 3, 2, 1), + # (1, 56, 56, 128, 3, 1, 1), + # (1, 56, 56, 128, 3, 2, 1), + # (1, 28, 28, 256, 3, 1, 1), + # (1, 28, 28, 256, 3, 2, 1), + # (1, 14, 14, 512, 3, 1, 1), + (1, 14, 14, 512, 3, 2, 1), + (1, 7, 7, 1024, 3, 1, 1), + ], + ), + "T2D": ( + conv2d_transpose_nhwc, + [ + # all conv2d tranpose layers in DCGAN + (1, 4, 4, 512, 256, 4, 2, 1), + (1, 8, 8, 256, 128, 4, 2, 1), + (1, 16, 16, 128, 64, 4, 2, 1), + (1, 32, 32, 64, 3, 4, 2, 1), + ], + ), + "CAP": ( + conv2d_capsule_nhwijc, + [ + # all conv2d capsule layers in matrix capsules withemrouting (ICLR 2018) + (1, 16, 16, 32, 32, 3, 2, 1), + (1, 8, 8, 32, 32, 3, 1, 1), + (1, 16, 16, 8, 16, 3, 2, 1), + (1, 8, 8, 16, 16, 3, 1, 1), + ], + ), + "NRM": ( + norm_bmn, + [ + (1, 256, 256), + (1, 512, 512), + (1, 1024, 1024), + (1, 4096, 1024), + ], + ), + "C2d-BN-RELU": ( + conv2d_nhwc_bn_relu, + [ + (1, 224, 224, 3, 64, 7, 2, 3), + (1, 56, 56, 64, 128, 3, 2, 1), + (1, 28, 28, 128, 256, 1, 2, 0), + (1, 7, 7, 512, 512, 3, 1, 1), + ], + ), + "TBG": ( + transpose_batch_matmul, + [ + (1, 128, 12, 64), + (1, 128, 16, 64), + (1, 64, 12, 128), + (1, 128, 12, 128), + ], + ), +} + diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 0f3cfac1a85f..af219086395a 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py 
@@ -16,7 +16,7 @@ # under the License. """Meta Schedule tuning context.""" -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, List from tvm import IRModule from tvm._ffi import register_object @@ -29,6 +29,9 @@ if TYPE_CHECKING: from .space_generator import SpaceGenerator from .search_strategy import SearchStrategy + from .schedule_rule import ScheduleRule + from .postproc import Postproc + from .mutator import Mutator @register_object("meta_schedule.TuneContext") @@ -50,6 +53,12 @@ class TuneContext(Object): The design space generator. search_strategy : Optional[SearchStrategy] = None The search strategy. + sch_rules : List[ScheduleRule] = [] + The schedule rules. + postproc : List[Postproc] = [] + The post processings. + mutator : List[Mutator] = [] + The mutators. task_name : Optional[str] = None The name of the tuning task. rand_state : int = -1 @@ -68,8 +77,11 @@ class TuneContext(Object): mod: Optional[IRModule] target: Optional[Target] - space_generator: "SpaceGenerator" - search_strategy: "SearchStrategy" + space_generator: Optional["SpaceGenerator"] + search_strategy: Optional["SearchStrategy"] + sch_rules: List["ScheduleRule"] + postproc: List["Postproc"] + mutator: List["Mutator"] task_name: Optional[str] rand_state: int num_threads: int @@ -80,6 +92,9 @@ def __init__( target: Optional[Target] = None, space_generator: Optional["SpaceGenerator"] = None, search_strategy: Optional["SearchStrategy"] = None, + sch_rules: List["ScheduleRule"] = [], + postproc: List["Postproc"] = [], + mutator: List["Mutator"] = [], task_name: Optional[str] = None, rand_state: int = -1, num_threads: Optional[int] = None, @@ -96,6 +111,12 @@ def __init__( The design space generator. search_strategy : Optional[SearchStrategy] = None The search strategy. + sch_rules : List[ScheduleRule] = [] + The schedule rules. + postproc : List[Postproc] = [] + The post processings. + mutator : List[Mutator] = [] + The mutators. 
task_name : Optional[str] = None The name of the tuning task. rand_state : int = -1 @@ -113,6 +134,9 @@ def __init__( target, space_generator, search_strategy, + sch_rules, + postproc, + mutator, task_name, rand_state, num_threads, diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index 3deac405d2c1..22fc2cc5be49 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Utilities for meta schedule""" +import ctypes import json import os import shutil @@ -207,6 +208,22 @@ def structural_hash(mod: IRModule) -> str: return str(shash) +def _get_hex_address(handle: ctypes.c_void_p) -> str: + """Get the hexadecimal address of a handle. + + Parameters + ---------- + handle : ctypes.c_void_p + The handle to be converted. + + Returns + ------- + result : str + The hexadecimal address of the handle. + """ + return hex(ctypes.cast(handle, ctypes.c_void_p).value) + + def check_override( derived_class: Any, base_class: Any, required: bool = True, func_name: str = None ) -> Callable: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 26632997dd57..3bb56aeab004 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -400,6 +400,21 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: """ return _ffi_api.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member + def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockRV]: + """Get the leaf blocks of a specific block/loop + + Parameters + ---------- + block_or_loop : Union[BlockRV, LoopRV] + The query block/loop + + Returns + ------- + blocks : List[LoopRV] + A list of leaf blocks inside a specific block/loop + """ + return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop) # pylint: disable=no-member + ########## Schedule: Transform loops ########## 
def fuse(self, *loops: List[LoopRV]) -> LoopRV: """Fuse a list of consecutive loops into one. It requires: @@ -959,6 +974,30 @@ def after_cache_write(a: T.handle, b: T.handle) -> None: self, block, write_buffer_index, storage_scope ) + ########## Schedule: Data movement ########## + + def read_at( + self, + loop: LoopRV, + block: BlockRV, + read_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member + self, loop, block, read_buffer_index, storage_scope + ) + + def write_at( + self, + loop: LoopRV, + block: BlockRV, + write_buffer_index: int, + storage_scope: str, + ) -> BlockRV: + return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member + self, loop, block, write_buffer_index, storage_scope + ) + ########## Schedule: Compute location ########## def compute_at( diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc new file mode 100644 index 000000000000..733d118c735d --- /dev/null +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +MeasureCallback MeasureCallback::PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, // + PyMeasureCallbackNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return MeasureCallback(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyMeasureCallbackNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyMeasureCallback's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); +TVM_REGISTER_NODE_TYPE(PyMeasureCallbackNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackApply") + .set_body_method(&MeasureCallbackNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.MeasureCallbackPyMeasureCallback") + .set_body_typed(MeasureCallback::PyMeasureCallback); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc new file mode 100644 index 000000000000..9bf6161b5507 --- /dev/null +++ b/src/meta_schedule/mutator/mutator.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +Mutator Mutator::PyMutator( + PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyMutatorNode::FApply f_apply, // + PyMutatorNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return Mutator(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyMutatorNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyMutator's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(MutatorNode); +TVM_REGISTER_NODE_TYPE(PyMutatorNode); + +TVM_REGISTER_GLOBAL("meta_schedule.MutatorInitializeWithTuneContext") + .set_body_method(&MutatorNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorApply").set_body_method(&MutatorNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.MutatorPyMutator").set_body_typed(Mutator::PyMutator); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc new file mode 100644 index 000000000000..ff069e2c68cb --- /dev/null +++ b/src/meta_schedule/postproc/postproc.cc @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +Postproc Postproc::PyPostproc( + PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyPostprocNode::FApply f_apply, // + PyPostprocNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return Postproc(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyPostprocNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyPostproc's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(PostprocNode); +TVM_REGISTER_NODE_TYPE(PyPostprocNode); + +TVM_REGISTER_GLOBAL("meta_schedule.PostprocInitializeWithTuneContext") + .set_body_method(&PostprocNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocApply").set_body_method(&PostprocNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocPyPostproc").set_body_typed(Postproc::PyPostproc); + +} // namespace meta_schedule +} // namespace tvm 
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc new file mode 100644 index 000000000000..f80f684dafa8 --- /dev/null +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +ScheduleRule ScheduleRule::PyScheduleRule( + PyScheduleRuleNode::FInitializeWithTuneContext f_initialize_with_tune_context, // + PyScheduleRuleNode::FApply f_apply, // + PyScheduleRuleNode::FAsString f_as_string) { + ObjectPtr n = make_object(); + n->f_initialize_with_tune_context = std::move(f_initialize_with_tune_context); + n->f_apply = std::move(f_apply); + n->f_as_string = std::move(f_as_string); + return ScheduleRule(n); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& n, ReprPrinter* p) { + const auto* self = n.as(); + ICHECK(self); + PyScheduleRuleNode::FAsString f_as_string = (*self).f_as_string; + ICHECK(f_as_string != nullptr) << "PyScheduleRule's AsString method not implemented!"; + p->stream << f_as_string(); + }); + +TVM_REGISTER_OBJECT_TYPE(ScheduleRuleNode); +TVM_REGISTER_NODE_TYPE(PyScheduleRuleNode); + +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInitializeWithTuneContext") + .set_body_method(&ScheduleRuleNode::InitializeWithTuneContext); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApply") + .set_body_method(&ScheduleRuleNode::Apply); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRulePyScheduleRule") + .set_body_typed(ScheduleRule::PyScheduleRule); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc new file mode 100644 index 000000000000..5c00b3dc04c4 --- /dev/null +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief A search strategy that generates measure candidates using space generator. */ +class ReplayFuncNode : public SearchStrategyNode { + public: + using TRandState = support::LinearCongruentialEngine::TRandState; + + /*! \brief The state of the search strategy. */ + struct State { + /*! \brief The search strategy itself */ + ReplayFuncNode* self; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int st; + /*! \brief `[st, ed)` are the indices of the next batch of candidates. */ + int ed; + + explicit State(ReplayFuncNode* self) : self(self), st(0), ed(self->num_trials_per_iter) {} + + inline Optional> GenerateMeasureCandidates(); + inline void NotifyRunnerResults(const Array& results); + }; + + /*! \brief The number of trials per iteration. */ + int num_trials_per_iter; + /*! \brief The number of total trials. */ + int num_trials_total; + + /*! \brief The module to be tuned. */ + IRModule mod_{nullptr}; + /*! \brief The metadata of the function arguments. */ + Array args_info_{nullptr}; + /*! \brief The space generator for measure candidates generation. */ + SpaceGenerator space_generator_{nullptr}; + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The state of the search strategy. 
*/ + std::unique_ptr state_ = nullptr; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("num_trials_per_iter", &num_trials_per_iter); + v->Visit("num_trials_total", &num_trials_total); + // `space_generator_` is not visited + // `mod_` is not visited + // `args_info_` is not visited + // `num_threads_` is not visited + // `rand_state_` is not visited + // `state_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.ReplayFunc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->space_generator_ = tune_context->space_generator.value(); + this->mod_ = tune_context->mod.value(); + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->state_.reset(); + } + + void PreTuning(const Array& design_spaces) final { + ICHECK(this->state_ == nullptr); + this->state_ = std::make_unique(this); + } + + void PostTuning() final { + ICHECK(this->state_ != nullptr); + this->state_.reset(); + } + + Optional> GenerateMeasureCandidates() final { + ICHECK(this->state_ != nullptr); + return this->state_->GenerateMeasureCandidates(); + } + + void NotifyRunnerResults(const Array& results) final { + ICHECK(this->state_ != nullptr); + this->state_->NotifyRunnerResults(results); + } +}; + +inline Optional> ReplayFuncNode::State::GenerateMeasureCandidates() { + if (st >= self->num_trials_total) { + return NullOpt; + } + ed = std::min(ed, self->num_trials_total); + Array result; + for (int i = st; i < ed; i++) { + Array schs = self->space_generator_->GenerateDesignSpace(self->mod_); + result.push_back(MeasureCandidate(schs[tir::SampleInt(&self->rand_state_, 0, schs.size())], + self->args_info_)); + } + return result; +} + +inline void ReplayFuncNode::State::NotifyRunnerResults(const Array& results) { + st += self->num_trials_per_iter; + ed += self->num_trials_per_iter; 
+} + +SearchStrategy SearchStrategy::ReplayFunc(int num_trials_per_iter, int num_trials_total) { + ObjectPtr n = make_object(); + n->num_trials_per_iter = num_trials_per_iter; + n->num_trials_total = num_trials_total; + return SearchStrategy(n); +} + +TVM_REGISTER_NODE_TYPE(ReplayFuncNode); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") + .set_body_typed(SearchStrategy::ReplayFunc); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 1c83aee8c0fd..c4ee3c467921 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -50,7 +50,7 @@ class ReplayTraceNode : public SearchStrategyNode { int num_trials_total; /*! \brief The module to be tuned. */ - IRModule mod_{nullptr}; + Array mod_{nullptr}; /*! \brief The metadata of the function arguments. */ Array args_info_{nullptr}; /*! \brief The number of threads to use. -1 means using logical cpu number. 
*/ @@ -74,9 +74,15 @@ class ReplayTraceNode : public SearchStrategyNode { TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); void InitializeWithTuneContext(const TuneContext& tune_context) final { - this->mod_ = tune_context->mod.value(); - this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(this->mod_)); + CHECK(tune_context->num_threads > 0) << "Number of threads has to be larger than 0."; this->num_threads_ = tune_context->num_threads; + + this->mod_.reserve(this->num_threads_); + for (int i = 0; i < this->num_threads_; i++) { + this->mod_.push_back(DeepCopyIRModule(tune_context->mod.value())); + } + + this->args_info_ = ArgInfo::FromPrimFunc(FindEntryFunc(tune_context->mod.value())); this->rand_state_ = ForkSeed(&tune_context->rand_state); this->state_.reset(); } @@ -118,7 +124,7 @@ inline Optional> ReplayTraceNode::State::GenerateMeasure tir::Trace trace = design_spaces[design_space_index]->trace().value(); tir::Trace new_trace = tir::Trace(trace->insts, {}); tir::Schedule sch = tir::Schedule::Traced( // - self->mod_, // + self->mod_[thread_id], // /*rand_state=*/ForkSeed(&rand_state), // /*debug_mode=*/0, // /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); @@ -142,7 +148,8 @@ SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int num_tria } TVM_REGISTER_NODE_TYPE(ReplayTraceNode); -TVM_REGISTER_GLOBAL("meta_schedule.ReplayTrace").set_body_typed(SearchStrategy::ReplayTrace); +TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") + .set_body_typed(SearchStrategy::ReplayTrace); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc new file mode 100644 index 000000000000..41afbc57d79b --- /dev/null +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief Collecting all the non-root blocks */ +class BlockCollector : public tir::StmtVisitor { + public: + static Array Collect(const tir::Schedule& sch) { // + return BlockCollector(sch).Run(); + } + + private: + /*! \brief Entry point */ + Array Run() { + for (const auto& kv : sch_->mod()->functions) { + const GlobalVar& gv = kv.first; // `gv->name_hint` is the name of the function + const BaseFunc& base_func = kv.second; // this can be PrimFunc or relay::Function + if (const auto* func = base_func.as()) { + func_name_ = gv->name_hint; + block_names_.clear(); + blocks_to_collect_.clear(); + root_block_ = func->body.as()->block.get(); + VisitStmt(func->body); + for (const String& block_name : blocks_to_collect_) { + results_.push_back(sch_->GetBlock(block_name, func_name_)); + } + } + } + return results_; + } + /*! \brief Constructor */ + explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {} + /*! 
\brief Override the Stmt visiting behaviour */ + void VisitStmt_(const tir::BlockNode* block) override { + tir::StmtVisitor::VisitStmt_(block); + if (block != root_block_) { + CHECK(block_names_.count(block->name_hint) == 0) + << "Duplicated block name " << block->name_hint << " in function " << func_name_ + << " not supported!"; + block_names_.insert(block->name_hint); + blocks_to_collect_.push_back(block->name_hint); + } + } + + /*! \brief The schedule to be collected */ + const tir::Schedule& sch_; + /*! \brief The set of func name and block name pair */ + std::unordered_set block_names_; + /* \brief The list of blocks to collect in order */ + Array blocks_to_collect_; + /*! \brief Function name & blocks of collection */ + Array results_; + /*! \brief The root block of the PrimFunc */ + const tir::BlockNode* root_block_; + /*! \brief Name of the current PrimFunc */ + String func_name_; +}; + +/*! + * \brief Design Space Generator that generates design spaces by applying schedule rules to blocks + * in post-DFS order. + * */ +class PostOrderApplyNode : public SpaceGeneratorNode { + public: + using TRandState = support::LinearCongruentialEngine::TRandState; + + /*! \brief The random state. -1 means using random number. */ + TRandState rand_state_ = -1; + /*! \brief The schedule rules to be applied in order. 
*/ + Array sch_rules_{nullptr}; + + void VisitAttrs(tvm::AttrVisitor* v) { + // `rand_state_` is not visited + // `sch_rules_` is not visited + } + + void InitializeWithTuneContext(const TuneContext& tune_context) final { + this->rand_state_ = ForkSeed(&tune_context->rand_state); + this->sch_rules_ = tune_context->sch_rules; + } + + Array GenerateDesignSpace(const IRModule& mod_) final { + using ScheduleAndUnvisitedBlocks = std::pair>; + tir::Schedule sch = tir::Schedule::Traced( // + /*mod=*/mod_, // + /*rand_state=*/ForkSeed(&this->rand_state_), // + /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, // + /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail // + ); + + std::vector stack; + Array result{sch}; + // Enumerate the schedule rules first because you can + // always concat multiple schedule rules as one + for (ScheduleRule sch_rule : sch_rules_) { + for (const tir::Schedule& sch : result) { + stack.emplace_back(sch, BlockCollector::Collect(sch)); + } + result.clear(); + + while (!stack.empty()) { + // get the stack.top() + tir::Schedule sch; + Array blocks; + std::tie(sch, blocks) = stack.back(); + stack.pop_back(); + // if all blocks are visited + if (blocks.empty()) { + result.push_back(sch); + continue; + } + // otherwise, get the last block that is not visited + tir::BlockRV block_rv = blocks.back(); + blocks.pop_back(); + if (sch->HasBlock(block_rv)) { + Array applied = sch_rule->Apply(sch, /*block=*/block_rv); + for (const tir::Schedule& sch : applied) { + stack.emplace_back(sch, blocks); + } + } + } + } + return result; + } + static constexpr const char* _type_key = "meta_schedule.PostOrderApply"; + TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); +}; + +SpaceGenerator SpaceGenerator::PostOrderApply() { + ObjectPtr n = make_object(); + return SpaceGenerator(n); +} + +TVM_REGISTER_NODE_TYPE(PostOrderApplyNode); +TVM_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") + 
.set_body_typed(SpaceGenerator::PostOrderApply); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index 3ef5026cae98..2bd7cf4bcfca 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -55,12 +55,14 @@ class RoundRobinNode final : public TaskSchedulerNode { TaskScheduler TaskScheduler::RoundRobin(Array tasks, // Builder builder, // Runner runner, // - Database database) { + Database database, // + Array measure_callbacks) { ObjectPtr n = make_object(); n->tasks = tasks; n->builder = builder; n->runner = runner; n->database = database; + n->measure_callbacks = measure_callbacks; n->task_id = -1; return TaskScheduler(n); } diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 08f2b4f451bd..3bd2ab9ddeea 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -101,6 +101,16 @@ void TaskSchedulerNode::InitializeTask(int task_id) { // Initialize Modules. space->InitializeWithTuneContext(task); strategy->InitializeWithTuneContext(task); + // Initialize the rules. 
+ for (const ScheduleRule& sch_rule : task->sch_rules) { + sch_rule->InitializeWithTuneContext(task); + } + for (const Mutator& mutator : task->mutators) { + mutator->InitializeWithTuneContext(task); + } + for (const Postproc& postproc : task->postprocs) { + postproc->InitializeWithTuneContext(task); + } } void TaskSchedulerNode::Tune() { @@ -201,6 +211,7 @@ TaskScheduler TaskScheduler::PyTaskScheduler( Builder builder, // Runner runner, // Database database, // + Array measure_callbacks, // PyTaskSchedulerNode::FTune f_tune, // PyTaskSchedulerNode::FInitializeTask f_initialize_task, // PyTaskSchedulerNode::FSetTaskStopped f_set_task_stopped, // @@ -212,6 +223,7 @@ TaskScheduler TaskScheduler::PyTaskScheduler( n->builder = builder; n->runner = runner; n->database = database; + n->measure_callbacks = measure_callbacks; n->f_tune = f_tune; n->f_initialize_task = f_initialize_task; n->f_set_task_stopped = f_set_task_stopped; diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 9fc9272e33ac..21ba8294fca4 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -24,20 +24,13 @@ namespace tvm { namespace meta_schedule { -/*! - * \brief Constructor function of TuneContext class. - * \param mod The mod to be optimized. - * \param target The target to be optimized for. - * \param space_generator The design space generator. - * \param task_name The name of the tuning task. - * \param rand_state The random state. - * \param num_threads The number of threads to be used. - * \param verbose The verbosity level. 
- */ TuneContext::TuneContext(Optional mod, // Optional target, // Optional space_generator, // Optional search_strategy, // + Array sch_rules, // + Array postprocs, // + Array mutators, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) { @@ -46,6 +39,9 @@ TuneContext::TuneContext(Optional mod, n->target = target; n->space_generator = space_generator; n->search_strategy = search_strategy; + n->sch_rules = sch_rules; + n->postprocs = postprocs; + n->mutators = mutators; n->task_name = task_name; if (rand_state == -1) { rand_state = std::random_device()(); @@ -65,11 +61,14 @@ TVM_REGISTER_GLOBAL("meta_schedule.TuneContext") Optional target, // Optional space_generator, // Optional search_strategy, // + Array sch_rules, // + Array postprocs, // + Array mutators, // Optional task_name, // support::LinearCongruentialEngine::TRandState rand_state, // int num_threads) -> TuneContext { - return TuneContext(mod, target, space_generator, search_strategy, task_name, rand_state, - num_threads); + return TuneContext(mod, target, space_generator, search_strategy, sch_rules, postprocs, + mutators, task_name, rand_state, num_threads); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 83e65a5ced44..be76d3e8db98 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -23,7 +23,11 @@ #include #include #include +#include +#include +#include #include +#include #include #include #include @@ -32,6 +36,7 @@ #include #include #include +#include #include #include @@ -193,7 +198,7 @@ inline support::LinearCongruentialEngine::TRandState ForkSeed( /*! * \brief Fork a random state into another ones, i.e. PRNG splitting. - * The given random state is also mutated. + * The given random state is also mutated. 
* \param rand_state The random state to be forked * \param n The number of forks * \return The forked random states @@ -208,6 +213,15 @@ inline std::vector ForkSeed( return results; } +/*! + * \brief Get deep copy of an IRModule. + * \param mod The IRModule to make a deep copy. + * \return The deep copy of the IRModule. + */ +inline IRModule DeepCopyIRModule(IRModule mod) { + return Downcast(LoadJSON(SaveJSON(mod))); +} + } // namespace meta_schedule } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 6801eb2acf94..54760abbe521 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -292,6 +292,24 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { return CreateRV(tir::GetLoops(this->GetSRef(block_rv))); } +Array ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) { + Array result; + TVM_TIR_SCHEDULE_BEGIN(); + result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(block_rv), false)); + TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); + this->state_->DebugVerify(); + return result; +} + +Array ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + Array result; + TVM_TIR_SCHEDULE_BEGIN(); + result = CreateRV(tir::GetChildBlocks(state_, this->GetSRef(loop_rv), false)); + TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_); + this->state_->DebugVerify(); + return result; +} + /******** Schedule: Transform loops ********/ LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { @@ -445,6 +463,30 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff return CreateRV(result); } +/******** Schedule: Data movement ********/ + +BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::ReadAt(state_, 
this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + +BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index, + storage_scope); + TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: Compute location ********/ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 9dd3626729db..ffab69c02a9a 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -72,6 +72,7 @@ class ConcreteScheduleNode : public ScheduleNode { inline PrimExpr Get(const ExprRV& expr_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; + inline bool HasBlock(const BlockRV& block_rv) const final; inline Array GetSRefs(const Array& rvs) const; inline Array GetSRefs(const Array& rvs) const; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } @@ -88,6 +89,8 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; + Array GetChildBlocks(const BlockRV& block_rv) override; + Array GetChildBlocks(const LoopRV& loop_rv) override; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs) override; Array Split(const LoopRV& loop_rv, const Array>& factors) 
override; @@ -102,6 +105,11 @@ class ConcreteScheduleNode : public ScheduleNode { const String& storage_scope) override; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) override; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) override; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) override; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, @@ -188,6 +196,19 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { return this->analyzer_->Simplify(transformed); } +inline bool ConcreteScheduleNode::HasBlock(const BlockRV& block_rv) const { + auto it = this->symbol_table_.find(block_rv); + if (it == this->symbol_table_.end()) { + return false; + } + const ObjectRef& obj = (*it).second; + const auto* sref = obj.as(); + if (sref == nullptr || sref->stmt == nullptr) { + return false; + } + return true; +} + inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { auto it = this->symbol_table_.find(block_rv); if (it == this->symbol_table_.end()) { diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index f2da7e24f409..9d6d63293145 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -98,6 +98,15 @@ Array GetBlocks(const ScheduleState& self, const String& name, const S * \return A list of loops above the given block in its scope, from outer to inner */ Array GetLoops(const StmtSRef& block_sref); +/*! 
+ * \brief Get the leaf blocks of a specific block/loop + * \param self The schedule state + * \param parent_sref The query block/loop + * \param inclusive Whether to include parent_sref + * \return A list of leaf blocks inside a specific block/loop + */ +Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref, + bool inclusive = false); /******** Schedule: Transform loops ********/ /*! * Split a loop into a list of consecutive loops. It requires: @@ -203,6 +212,15 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r */ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope); + +/******** Schedule: Data movement ********/ + +TVM_DLL StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope); + +TVM_DLL StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope); + /******** Schedule: Compute location ********/ /*! 
* \brief Move a producer block under the specific loop, and regenerate the diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 8b32a9c14f58..4835c0854cdc 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -55,6 +55,31 @@ Array GetLoops(const StmtSRef& block_sref) { return {result.rbegin(), result.rend()}; } +Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref, + bool inclusive) { + struct Collector : public StmtVisitor { + private: + void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); } + + public: + explicit Collector(const ScheduleState& self) : self(self) {} + + const ScheduleState& self; + Array result; + }; + Collector collector(self); + if (inclusive) { + collector(GetRef(parent_sref->stmt)); + } else if (parent_sref->stmt->IsInstance()) { + const auto* loop = static_cast(parent_sref->stmt); + collector(loop->body); + } else if (parent_sref->stmt->IsInstance()) { + const auto* block = static_cast(parent_sref->stmt); + collector(block->body); + } + return std::move(collector.result); +} + /******** InstructionKind Registration ********/ struct GetBlockTraits : public UnpackedInstTraits { @@ -106,8 +131,39 @@ struct GetLoopsTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct GetChildBlocksTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetChildBlocks"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) { + if (const auto* block = block_or_loop_rv.as()) { + return sch->GetChildBlocks(GetRef(block)); + } + if (const auto* loop = block_or_loop_rv.as()) { + return sch->GetChildBlocks(GetRef(loop)); + } + 
LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey(); + throw; + } + + static String UnpackedAsPython(Array outputs, String block_or_loop_rv) { + PythonAPICall py("get_child_blocks"); + py.Input("", block_or_loop_rv); + py.OutputList(outputs); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits); TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/primitive/read_write_at.cc b/src/tir/schedule/primitive/read_write_at.cc new file mode 100644 index 000000000000..cb693c77cdc5 --- /dev/null +++ b/src/tir/schedule/primitive/read_write_at.cc @@ -0,0 +1,425 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include + +#include "../utils.h" +#include "tvm/runtime/memory.h" +#include "tvm/runtime/object.h" +#include "tvm/tir/schedule/block_scope.h" +#include "tvm/tir/stmt_functor.h" + +namespace tvm { +namespace tir { + +using support::NDIntSet; + +bool HasBuffer(const Array& buffer_regions, const Buffer& buffer) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + return true; + } + } + return false; +} + +void RelaxBufferRegions(const Array& buffer_regions, + const Buffer& buffer, // + const Map& var_dom, // + const Map& bindings, // + std::vector* relaxed_regions) { + for (const BufferRegion& buffer_region : buffer_regions) { + if (buffer_region->buffer.same_as(buffer)) { + Array relaxed_region = + arith::EvalSet(Substitute(buffer_region->region, bindings), var_dom); + relaxed_regions->push_back({relaxed_region.begin(), relaxed_region.end()}); + } + } +} + +class ScopeReplacer : public StmtMutator { + public: + static Block Replace(const BlockNode* scope_block, const Buffer& dst, const ForNode* old_loop, + const ForNode* new_loop) { + ObjectPtr new_scope_block = make_object(*scope_block); + new_scope_block->body = ScopeReplacer(old_loop, new_loop)(std::move(new_scope_block->body)); + new_scope_block->alloc_buffers.push_back(dst); + return Block(new_scope_block); + } + + private: + explicit ScopeReplacer(const ForNode* old_loop, const ForNode* new_loop) + : old_loop_(old_loop), new_loop_(new_loop), found_(false) {} + + Stmt VisitStmt(const Stmt& stmt) final { return found_ ? 
stmt : StmtMutator::VisitStmt(stmt); } + Stmt VisitStmt_(const BlockNode* block) final { return GetRef(block); } + Stmt VisitStmt_(const ForNode* loop) final { + if (loop == old_loop_) { + found_ = true; + return GetRef(new_loop_); + } + return StmtMutator::VisitStmt_(loop); + } + + const ForNode* old_loop_; + const ForNode* new_loop_; + bool found_; +}; + +class BufferReplacer : public StmtExprMutator { + public: + explicit BufferReplacer(const Buffer& src, const Buffer& dst, Map* block_sref_reuse) + : src_(src), dst_(dst), block_sref_reuse_(block_sref_reuse) {} + + private: + Stmt VisitStmt_(const BufferStoreNode* _store) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); + if (store->buffer.same_as(src_)) { + ObjectPtr new_store = make_object(*store.get()); + new_store->buffer = dst_; + return BufferStore(new_store); + } + return store; + } + + PrimExpr VisitExpr_(const BufferLoadNode* _load) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_load)); + if (load->buffer.same_as(src_)) { + ObjectPtr new_load = make_object(*load.get()); + new_load->buffer = dst_; + return BufferLoad(new_load); + } + return load; + } + + Stmt VisitStmt_(const BlockNode* _block) final { + Block old_block = GetRef(_block); + Block block = Downcast(StmtExprMutator::VisitStmt_(_block)); + ObjectPtr new_block = make_object(*block.get()); + new_block->reads = ReplaceBuffer(new_block->reads, src_, dst_); + new_block->writes = ReplaceBuffer(new_block->writes, src_, dst_); + block_sref_reuse_->Set(old_block, Block(new_block)); + return Block(new_block); + } + + const Buffer& src_; + const Buffer& dst_; + Map* block_sref_reuse_; +}; + +struct ReadWriteAtImpl { + template + static StmtSRef Main(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope, + Map annotations) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Buffer src = + GetNthAccessBuffer(self, 
GetRef(block), buffer_index, /*is_write=*/!is_read); + Buffer dst = WithScope(src, storage_scope); + ReadWriteAtImpl impl(self, loop_sref, src, dst, annotations); + std::pair new_loop_block = + impl.MakeLoopAndBlock(src->name + "_" + storage_scope); + StmtSRef result_block_sref = + impl.ReplaceScopeBlock(new_loop_block.first.get(), new_loop_block.second->block.get()); + impl.UpdateBlockInfo(result_block_sref); + return result_block_sref; + } + + private: + static Map GetLoopDomain(const StmtSRefNode* loop_sref) { + Map result; + for (const ForNode* loop; (loop = loop_sref->StmtAs()) != nullptr; + loop_sref = loop_sref->parent) { + result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + return result; + } + + StmtSRef ReplaceScopeBlock(const ForNode* new_loop, const BlockNode* new_block) { + StmtSRef scope_root_sref = GetScopeRoot(self_, loop_sref_, + /*require_stage_pipeline=*/true, + /*require_subtree_compact_dataflow=*/false); + const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_root_sref); + Block new_scope_block = ScopeReplacer::Replace(scope_block, dst_, loop_, new_loop); + block_sref_reuse_.Set(GetRef(scope_block), new_scope_block); + self_->Replace(scope_root_sref, new_scope_block, block_sref_reuse_); + return self_->stmt2ref.at(new_block); + } + + void UpdateBlockInfo(const StmtSRef& new_block_sref) { + BlockInfo& block_info = self_->block_info[new_block_sref]; + block_info.affine_binding = true; + block_info.region_cover = true; + block_info.scope->stage_pipeline = true; + } + + template + std::pair MakeLoopAndBlock(const String& new_block_name_hint) { + Array subtrees = AsArray(loop_->body); + int n_subtrees = subtrees.size(); + runtime::StorageScope scope = runtime::StorageScope::Create(dst_.scope()); + std::vector relaxed_regions; + std::vector r_pos; + std::vector w_pos; + relaxed_regions.reserve(n_subtrees); + r_pos.reserve(n_subtrees); + w_pos.reserve(n_subtrees); + // Step 1. 
Iterate over all subtrees + for (int i = 0; i < n_subtrees; ++i) { + bool r_visited = false; + bool w_visited = false; + auto f_visit = [this, &relaxed_regions, &r_visited, &w_visited, + &scope](const ObjectRef& obj) -> bool { + const BlockRealizeNode* realize = obj.as(); + if (realize == nullptr) { + return true; + } + const BlockNode* block = realize->block.get(); + bool has_r = HasBuffer(block->reads, src_); + bool has_w = HasBuffer(block->writes, src_); + r_visited = r_visited || has_r; + w_visited = w_visited || has_w; + if (is_read ? has_r : has_w) { + RelaxBufferRegions( + /*buffer_regions=*/is_read ? block->reads : block->writes, + /*buffer=*/src_, + /*var_dom=*/ + AsIntSet(LoopDomainOfSRefTreePath( + /*low_inclusive=*/GetRef(self_->stmt2ref.at(block)->parent), + /*high_exclusive=*/loop_sref_, + /*extra_relax_scope=*/scope)), + /*bindings=*/GetBindings(GetRef(realize)), + /*relaxed_regions=*/&relaxed_regions); + } + return false; + }; + PreOrderVisit(subtrees[i], f_visit); + if (r_visited) { + r_pos.push_back(i); + } + if (w_visited) { + w_pos.push_back(i); + } + } + // Step 2. Calculate `insert_pos` and [st, ed) for buffer replacement + int insert_pos = -1, st = -1, ed = -1; + if (is_read) { + ICHECK(!r_pos.empty()); + // No write after the first read + ICHECK(w_pos.empty() || w_pos.back() < r_pos.front()); + // Can be inserted at [0, r_pos.front()], i.e. before the first read + insert_pos = r_pos.front(); + // Buffer reads in [insert_pos, +oo) is rewritten + st = insert_pos; + ed = n_subtrees; + } else { + ICHECK(!w_pos.empty()); + // No read after the last write + ICHECK(r_pos.empty() || r_pos.back() <= w_pos.back()); + // Can be inserted into (w_pos.back(), +oo), i.e. after the last write + insert_pos = w_pos.back() + 1; + st = 0; + ed = insert_pos; + } + // Step 3. 
Calculate `domain`, the domain of buffer access + NDIntSet relaxed = support::NDIntSetUnion(relaxed_regions); + int ndim = relaxed.size(); + Array domain; + domain.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + const arith::IntSet& int_set = relaxed[i]; + PrimExpr min = analyzer_->Simplify(int_set.min()); + PrimExpr extent = analyzer_->Simplify(int_set.max() + 1 - min); + domain.push_back(Range::FromMinExtent(min, extent)); + } + // Step 4. Insert the auto copy block and replace buffers + BufferReplacer replacer(src_, dst_, &block_sref_reuse_); + for (int i = st; i < ed; ++i) { + Stmt stmt = subtrees[i]; + subtrees.Set(i, Stmt(nullptr)); + subtrees.Set(i, replacer(std::move(stmt))); + } + BlockRealize realize = + is_read + ? MakeBlock(src_, dst_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain) + : MakeBlock(dst_, src_, new_block_name_hint, GetLoopDomain(loop_sref_.get()), domain); + subtrees.insert(subtrees.begin() + insert_pos, realize); + ObjectPtr new_loop = make_object(*loop_); + new_loop->body = SeqStmt(std::move(subtrees)); + return {For(new_loop), realize}; + } + + BlockRealize MakeBlock(const Buffer& copy_from, const Buffer& copy_to, const String& name_hint, + const Map& loop_domain, Array domain) const { + int n = domain.size(); + std::vector loop_vars; + loop_vars.reserve(n); + for (int i = 0; i < n; ++i) { + loop_vars.push_back(Var("ax" + std::to_string(i))); + } + Map bindings; + Array iter_vars; + Array iter_values; + Array indices; + iter_vars.reserve(n); + iter_values.reserve(n); + indices.reserve(n); + for (int i = 0; i < n; ++i) { + auto f_substitute = [&loop_domain, &bindings, &iter_vars, + &iter_values](const Var& var) -> Optional { + auto it = bindings.find(var); + if (it != bindings.end()) { + return (*it).second; + } + Range range = loop_domain.at(var); + ObjectPtr v = make_object(*var.get()); + v->name_hint = "v" + std::to_string(iter_vars.size()); + bindings.Set(var, Var(v)); + iter_values.push_back(var); + 
iter_vars.push_back(IterVar(range, Var(v), IterVarType::kDataPar)); + return Var(v); + }; + ObjectPtr dom = make_object(*domain[i].get()); + dom->min = Substitute(std::move(dom->min), f_substitute); + dom->extent = Substitute(std::move(dom->extent), f_substitute); + domain.Set(i, Range(dom)); + } + for (int i = 0; i < n; ++i) { + indices.push_back(domain[i]->min + loop_vars[i]); + } + Stmt stmt = BufferStore(copy_to, /*value=*/BufferLoad(copy_from, indices), /*indices=*/indices); + for (int i = n - 1; i >= 0; --i) { + stmt = For(loop_vars[i], Integer(0), domain[i]->extent, ForKind::kSerial, stmt); + } + return BlockRealize( + /*values=*/iter_values, + /*predicate=*/const_true(), + Block(/*iter_vars=*/iter_vars, + /*reads=*/{BufferRegion(copy_from, domain)}, + /*writes=*/{BufferRegion(copy_to, domain)}, + /*name_hint=*/name_hint, // + /*body=*/std::move(stmt), + /*init=*/NullOpt, + /*alloc_buffers=*/{}, + /*match_buffers=*/{}, + /*annotations=*/annotations_)); + } + + explicit ReadWriteAtImpl(ScheduleState self, const StmtSRef& loop_sref, const Buffer& src, + const Buffer& dst, Map annotations) + : self_(self), + loop_sref_(loop_sref), + loop_(nullptr), + src_(src), + dst_(dst), + annotations_(annotations), + block_sref_reuse_(), + analyzer_(std::make_unique()) { + loop_ = TVM_SREF_TO_FOR(loop_, loop_sref); + } + + ScheduleState self_; + const StmtSRef& loop_sref_; + const ForNode* loop_; + const Buffer& src_; + const Buffer& dst_; + Map annotations_; + Map block_sref_reuse_; + std::unique_ptr analyzer_; +}; + +StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int read_buffer_index, const String& storage_scope) { + return ReadWriteAtImpl::Main(self, loop_sref, block_sref, read_buffer_index, storage_scope, + {{"auto_copy", Integer(1)}}); +} + +StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int write_buffer_index, const String& storage_scope) { + return 
ReadWriteAtImpl::Main(self, loop_sref, block_sref, write_buffer_index, + storage_scope, {{"auto_copy", Integer(1)}}); +} + +/******** Instruction Registration ********/ + +struct ReadAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "ReadAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref, + int buffer_index, const String& storage_scope); + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer read_buffer_index, String storage_scope) { + return sch->ReadAt(loop, block, read_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer read_buffer_index, String storage_scope) { + PythonAPICall py("read_at"); + py.Input("loop", loop); + py.Input("block", block); + py.Input("read_buffer_index", read_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +struct WriteAtTraits : public UnpackedInstTraits { + static constexpr const char* kName = "WriteAt"; + static constexpr bool kIsPure = false; + + private: + static constexpr size_t kNumInputs = 2; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 0; + + static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop, BlockRV block, + Integer write_buffer_index, String storage_scope) { + return sch->WriteAt(loop, block, write_buffer_index->value, storage_scope); + } + + static String UnpackedAsPython(Array outputs, String loop, String block, + Integer write_buffer_index, String storage_scope) { + PythonAPICall py("write_at"); + py.Input("loop", loop); + py.Input("block", block); + 
py.Input("write_buffer_index", write_buffer_index->value); + py.Input("storage_scope", storage_scope); + py.SingleOutput(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(ReadAtTraits); +TVM_REGISTER_INST_KIND_TRAITS(WriteAtTraits); + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 171838572dbb..4acf61860112 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -20,6 +20,7 @@ #include #include "../utils.h" +#include "tvm/support/random_engine.h" namespace tvm { namespace tir { diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 4ef456a9527b..8f7caa914530 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -130,6 +130,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks") + .set_body_typed([](Schedule self, ObjectRef rv) { + if (const auto* block_rv = rv.as()) { + return self->GetChildBlocks(GetRef(block_rv)); + } + if (const auto* loop_rv = rv.as()) { + return self->GetChildBlocks(GetRef(loop_rv)); + } + LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey() + << ". 
Its value is: " << rv; + throw; + }); /******** (FFI) Transform loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); @@ -147,6 +159,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheRead") .set_body_method(&ScheduleNode::CacheRead); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCacheWrite") .set_body_method(&ScheduleNode::CacheWrite); +/******** (FFI) Data movement ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReadAt").set_body_method(&ScheduleNode::ReadAt); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWriteAt") + .set_body_method(&ScheduleNode::WriteAt); /******** (FFI) Compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt") .set_body_method(&ScheduleNode::ComputeAt); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index d1e103ce2d01..94f15d5c6543 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -97,6 +97,28 @@ Array TracedScheduleNode::GetLoops(const BlockRV& block_rv) { return results; } +Array TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) { + Array results = ConcreteScheduleNode::GetChildBlocks(block_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{block_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + +Array TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) { + Array results = ConcreteScheduleNode::GetChildBlocks(loop_rv); + + static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{loop_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + /******** Schedule: Transform loops ********/ LoopRV 
TracedScheduleNode::Fuse(const Array& loop_rvs) { @@ -206,6 +228,31 @@ BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer return result; } +BlockRV TracedScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int read_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::ReadAt(loop_rv, block_rv, read_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("ReadAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(read_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} + +BlockRV TracedScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, + int write_buffer_index, const String& storage_scope) { + BlockRV result = + ConcreteScheduleNode::WriteAt(loop_rv, block_rv, write_buffer_index, storage_scope); + + static const InstructionKind& kind = InstructionKind::Get("WriteAt"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, + /*inputs=*/{loop_rv, block_rv}, + /*attrs=*/{Integer(write_buffer_index), storage_scope}, + /*outputs=*/{result})); + return result; +} /******** Schedule: Compute location ********/ void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 81e0fe84d517..d5676f4cdce7 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -54,6 +54,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") final; Array GetLoops(const BlockRV& block_rv) final; + Array GetChildBlocks(const BlockRV& block_rv) final; + Array GetChildBlocks(const LoopRV& loop_rv) final; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs) final; Array Split(const LoopRV& 
loop_rv, const Array>& factor_rvs) final; @@ -68,6 +70,11 @@ class TracedScheduleNode : public ConcreteScheduleNode { const String& storage_scope) final; BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope) final; + /******** Schedule: Data movement ********/ + BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index, + const String& storage_scope) final; + BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index, + const String& storage_scope) final; /******** Schedule: Compute location ********/ void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final; void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, diff --git a/tests/python/unittest/test_meta_schedule_measure_callback.py b/tests/python/unittest/test_meta_schedule_measure_callback.py new file mode 100644 index 000000000000..e7217ea0007e --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_measure_callback.py @@ -0,0 +1,135 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import re +from typing import List + +import tvm +from tvm.ir.base import assert_structural_equal +from tvm.meta_schedule.runner.runner import Runner +from tvm.meta_schedule.task_scheduler.task_scheduler import TaskScheduler +from tvm.meta_schedule.tune_context import TuneContext +from tvm.script import tir as T + +from tvm.meta_schedule.measure_callback import PyMeasureCallback +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.builder import BuilderResult +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.utils import _get_hex_address + +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_measure_callback(): + class FancyMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: TaskScheduler, + tasks: List[TuneContext], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + assert len(measure_candidates) == 1 + assert_structural_equal(measure_candidates[0].sch.mod, Matmul) + assert ( + len(builds) == 1 + and builds[0].error_msg is None + and builds[0].artifact_path == "test_build" + ) + assert ( + 
len(results) == 1 and results[0].error_msg is None and len(results[0].run_secs) == 2 + ) + return True + + measure_callback = FancyMeasureCallback() + assert measure_callback.apply( + TaskScheduler(), + [], + [MeasureCandidate(Schedule(Matmul), None)], + [BuilderResult("test_build", None)], + [RunnerResult([1.0, 2.1], None)], + ) + + +def test_meta_schedule_measure_callback_fail(): + class FailingMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: TaskScheduler, + tasks: List[TuneContext], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + return False + + measure_callback = FailingMeasureCallback() + assert not measure_callback.apply( + TaskScheduler(), + [], + [MeasureCandidate(None, None)], + [BuilderResult(None, None)], + [RunnerResult(None, None)], + ) + + +def test_meta_schedule_measure_callback_as_string(): + class NotSoFancyMeasureCallback(PyMeasureCallback): + def apply( + self, + task_scheduler: "TaskScheduler", + tasks: List["TuneContext"], + measure_candidates: List[MeasureCandidate], + builds: List[BuilderResult], + results: List[RunnerResult], + ) -> bool: + pass + + def __str__(self) -> str: + return f"NotSoFancyMeasureCallback({_get_hex_address(self.handle)})" + + measure_callback = NotSoFancyMeasureCallback() + pattern = re.compile(r"NotSoFancyMeasureCallback\(0x[a-f|0-9]*\)") + assert pattern.match(str(measure_callback)) + + +if __name__ == "__main__": + test_meta_schedule_measure_callback() + test_meta_schedule_measure_callback_fail() + test_meta_schedule_measure_callback_as_string() diff --git a/tests/python/unittest/test_meta_schedule_mutator.py b/tests/python/unittest/test_meta_schedule_mutator.py new file mode 100644 index 000000000000..b4d94dc9a8e3 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_mutator.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from typing import List, Optional + +import re + +import tvm +from tvm.ir.base import assert_structural_equal +from tvm.script import tir as T + +from tvm.meta_schedule.mutator import PyMutator +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.utils import _get_hex_address +from tvm.tir.schedule import Schedule, Trace + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_meta_schedule_mutator(): + class FancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + 
pass + + def apply(self, trace: Trace) -> Optional[Trace]: + return Trace(trace.insts, {}) + + mutator = FancyMutator() + sch = Schedule(Matmul) + res = mutator.apply(sch.trace) + assert res is not None + new_sch = sch.copy() + res.apply_to_schedule(new_sch, remove_postproc=True) + assert_structural_equal(sch.mod, new_sch.mod) + + +def test_meta_schedule_mutator_as_string(): + class YetAnotherFancyMutator(PyMutator): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, trace: Trace) -> Optional[Trace]: + pass + + def __str__(self) -> str: + return f"YetAnotherFancyMutator({_get_hex_address(self.handle)})" + + mutator = YetAnotherFancyMutator() + pattern = re.compile(r"YetAnotherFancyMutator\(0x[a-f|0-9]*\)") + assert pattern.match(str(mutator)) + + +if __name__ == "__main__": + test_meta_schedule_mutator() + test_meta_schedule_mutator_as_string() diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py new file mode 100644 index 000000000000..95b5ed002b27 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -0,0 +1,340 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from typing import List +import pytest +import math +import sys + +import tvm +from tvm._ffi.base import TVMError, py2cerror +from tvm.ir.base import assert_structural_equal +from tvm.script import tir as T +from tvm.tir.schedule import Schedule, BlockRV, block_scope +from tvm.target import Target + +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.meta_schedule.utils import _get_hex_address +from tvm.tir.schedule import trace +from tvm.tir.schedule.trace import Trace + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class DuplicateMatmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + 
+@tvm.script.ir_module +class TrinityMatmul: + @T.prim_func + def main(a: T.handle, d: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.alloc_buffer((1024, 1024), "float32") + C = T.alloc_buffer((1024, 1024), "float32") + D = T.match_buffer(d, (1024, 1024), "float32") + for i, j in T.grid(1024, 1024): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(1024, 1024): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 3.0 + for i, j in T.grid(1024, 1024): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = C[vi, vj] * 5.0 + + +@tvm.script.ir_module +class TrinityMatmulProcessedForReference: + @T.prim_func + def main(a: T.handle, d: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, [1024, 1024], dtype="float32") + D = T.match_buffer(d, [1024, 1024], dtype="float32") + # body + # with tir.block("root") + B = T.alloc_buffer([1024, 1024], dtype="float32") + for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): + with T.block("A"): + vi = T.axis.S(1024, i0_0 * 64 + i0_1) + vj = T.axis.S(1024, i1_0 * 16 + i1_1) + T.reads([A[vi, vj]]) + T.writes([B[vi, vj]]) + B[vi, vj] = A[vi, vj] * T.float32(2) + for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): + with T.block("C"): + vi = T.axis.S(1024, i0_0 * 64 + i0_1) + vj = T.axis.S(1024, i1_0 * 16 + i1_1) + T.reads([B[vi, vj]]) + T.writes([D[vi, vj]]) + D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5) + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +class WowSoFancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass 
+ + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + return [new_sch] + + +class DoubleScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[4, 64, 2, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[2, 4, 64, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + result = [new_sch] + new_sch = sch.copy() + i, j, k = new_sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = new_sch.split(loop=i, factors=[4, 64, 2, 2]) + j_0, j_1, j_2, j_3 = new_sch.split(loop=j, factors=[2, 4, 64, 2]) + k_0, k_1 = new_sch.split(loop=k, factors=[32, 32]) + new_sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + result.append(new_sch) + return result + + +class ReorderScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + new_sch = sch.copy() + i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = new_sch.get_loops(block=block) + new_sch.reorder(i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, i_0, j_0) + result = [new_sch] + new_sch = sch.copy() + i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3 = new_sch.get_loops(block=block) + new_sch.reorder(i_1, j_3, i_0, j_0, j_1, k_0, i_2, j_2, k_1, i_3) + result.append(new_sch) + return result + + +def test_meta_schedule_post_order_apply(): + mod = 
Matmul + context = TuneContext( + mod=mod, target=Target("llvm"), task_name="Test Task", sch_rules=[WowSoFancyScheduleRule()] + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 1 + try: + tvm.ir.assert_structural_equal(mod, schs[0].mod) + raise Exception("The schedule rule did not change the schedule.") + except (ValueError): + _check_correct(schs[0]) + + +def test_meta_schedule_post_order_apply_double(): + mod = Matmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Double Rules Task", + sch_rules=[DoubleScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 2 + for sch in schs: + try: + tvm.ir.assert_structural_equal(mod, sch.mod) + raise Exception("The schedule rule did not change the schedule.") + except (ValueError): + _check_correct(sch) + + +def test_meta_schedule_post_order_apply_multiple(): + mod = Matmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Double Rules Task", + sch_rules=[DoubleScheduleRule(), ReorderScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 4 + for sch in schs: + try: + tvm.ir.assert_structural_equal(mod, sch.mod) + raise Exception("The schedule rule did not change the schedule.") + except (ValueError): + _check_correct(sch) + + +def test_meta_schedule_post_order_apply_duplicate_matmul(): + mod = DuplicateMatmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Duplicate Matmul Task", + sch_rules=[WowSoFancyScheduleRule()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + with pytest.raises( + TVMError, + 
match=r".*TVMError: Check failed: \(block_names_.count\(block->name_hint\) == 0\)" + r" is false: Duplicated block name matmul in function main not supported!", + ): + post_order_apply.generate_design_space(mod) + + +def test_meta_schedule_post_order_apply_remove_block(): + class TrinityDouble(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[16, 64]) + j_0, j_1 = new_sch.split(loop=j, factors=[64, 16]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result = [new_sch] + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[2, 512]) + j_0, j_1 = new_sch.split(loop=j, factors=[2, 512]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result.append(new_sch) + return result + + class RemoveBlock(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + sch = sch.copy() + if sch.get(block).name_hint == "B": + sch.compute_inline(block) + return [sch] + + def correct_trace(a, b, c, d): + return "\n".join( + [ + 'b0 = sch.get_block(name="A", func_name="main")', + 'b1 = sch.get_block(name="B", func_name="main")', + 'b2 = sch.get_block(name="C", func_name="main")', + "sch.compute_inline(block=b1)", + 'b3 = sch.get_block(name="A", func_name="main")', + 'b4 = sch.get_block(name="C", func_name="main")', + "l5, l6 = sch.get_loops(block=b4)", + "l7, l8 = sch.split(loop=l5, factors=" + str(a) + ")", + "l9, l10 = sch.split(loop=l6, factors=" + str(b) + ")", + "sch.reorder(l7, l9, l8, l10)", + "l11, l12 = sch.get_loops(block=b3)", + "l13, l14 = sch.split(loop=l11, factors=" + str(c) + ")", + "l15, l16 = sch.split(loop=l12, factors=" + str(d) + ")", + "sch.reorder(l13, l15, l14, l16)", + ] + ) + + mod = 
TrinityMatmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Remove Block Task", + sch_rules=[RemoveBlock(), TrinityDouble()], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + schs = post_order_apply.generate_design_space(mod) + assert len(schs) == 4 + for sch in schs: + with pytest.raises( + tvm.tir.schedule.schedule.ScheduleError, + match="ScheduleError: An error occurred in the schedule primitive 'get-block'.", + ): + sch.get_block("B", "main") + assert ( + str(sch.trace) == correct_trace([16, 64], [64, 16], [2, 512], [2, 512]) + or str(sch.trace) == correct_trace([2, 512], [2, 512], [2, 512], [2, 512]) + or str(sch.trace) == correct_trace([16, 64], [64, 16], [16, 64], [64, 16]) + or str(sch.trace) == correct_trace([2, 512], [2, 512], [16, 64], [64, 16]) + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_postproc.py b/tests/python/unittest/test_meta_schedule_postproc.py new file mode 100644 index 000000000000..52f07fdff099 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import math +import re + +import tvm +from tvm.script import tir as T + +from tvm.meta_schedule.postproc import PyPostproc +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.utils import _get_hex_address + +from tvm.tir.schedule import Schedule + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +def schedule_matmul(sch: Schedule): + block = sch.get_block("matmul") + i, j, k = sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + + +def test_meta_schedule_postproc(): + class FancyPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + schedule_matmul(sch) + return True + + postproc = FancyPostproc() + mod = Matmul + sch = Schedule(mod) + assert postproc.apply(sch) + try: + tvm.ir.assert_structural_equal(sch.mod, mod) + raise Exception("The 
post processing did not change the schedule.") + except (ValueError): + _check_correct(sch) + + +def test_meta_schedule_postproc_fail(): + class FailingPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + return False + + postproc = FailingPostproc() + sch = Schedule(Matmul) + assert not postproc.apply(sch) + + +def test_meta_schedule_postproc_as_string(): + class NotSoFancyPostproc(PyPostproc): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule) -> bool: + pass + + def __str__(self) -> str: + return f"NotSoFancyPostproc({_get_hex_address(self.handle)})" + + postproc = NotSoFancyPostproc() + pattern = re.compile(r"NotSoFancyPostproc\(0x[a-f|0-9]*\)") + assert pattern.match(str(postproc)) + + +if __name__ == "__main__": + test_meta_schedule_postproc() + test_meta_schedule_postproc_fail() + test_meta_schedule_postproc_as_string() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule.py b/tests/python/unittest/test_meta_schedule_schedule_rule.py new file mode 100644 index 000000000000..e79ca69ca64d --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from typing import List + +import math +import re + +import tvm +from tvm.script import tir as T + +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.utils import _get_hex_address + +from tvm.tir.schedule import Schedule, BlockRV + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def _check_correct(schedule: Schedule): + trace = schedule.trace + for inst in trace.decisions: + assert math.prod(trace.decisions[inst]) == 1024 + + +def test_meta_schedule_schedule_rule(): + class FancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + i, j, k = sch.get_loops(block=block) + i_0, i_1, i_2, i_3 = sch.split(loop=i, factors=[2, 4, 64, 2]) + j_0, j_1, j_2, j_3 = sch.split(loop=j, factors=[4, 64, 2, 2]) + k_0, k_1 = sch.split(loop=k, factors=[32, 32]) + sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) + return [sch] + + sch_rule = FancyScheduleRule() + mod = Matmul + sch = Schedule(mod) + res = 
sch_rule.apply(sch, block=sch.get_block("matmul")) + assert len(res) == 1 + try: + tvm.ir.assert_structural_equal(mod, res[0].mod) + raise Exception("The schedule rule did not change the schedule.") + except (ValueError): + _check_correct(res[0]) + + +def test_meta_schedule_schedule_rule_as_string(): + class YetStillSomeFancyScheduleRule(PyScheduleRule): + def initialize_with_tune_context(self, tune_context: "TuneContext") -> None: + pass + + def apply(self, schedule: Schedule, block: BlockRV) -> List[Schedule]: + pass + + def __str__(self) -> str: + return f"YetStillSomeFancyScheduleRule({_get_hex_address(self.handle)})" + + sch_rule = YetStillSomeFancyScheduleRule() + pattern = re.compile(r"YetStillSomeFancyScheduleRule\(0x[a-f|0-9]*\)") + assert pattern.match(str(sch_rule)) + + +if __name__ == "__main__": + test_meta_schedule_schedule_rule() + test_meta_schedule_schedule_rule_as_string() diff --git a/tests/python/unittest/test_meta_schedule_search_strategy.py b/tests/python/unittest/test_meta_schedule_search_strategy.py index 9b3ddfd7c789..f940d11b79e4 100644 --- a/tests/python/unittest/test_meta_schedule_search_strategy.py +++ b/tests/python/unittest/test_meta_schedule_search_strategy.py @@ -26,7 +26,7 @@ from tvm.meta_schedule import TuneContext from tvm.meta_schedule.runner import RunnerResult from tvm.meta_schedule.space_generator import ScheduleFn -from tvm.meta_schedule.search_strategy import ReplayTrace +from tvm.meta_schedule.search_strategy import SearchStrategy, ReplayTrace, ReplayFunc from tvm.script import tir as T from tvm.tir.schedule import Schedule, Trace @@ -56,9 +56,13 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument -def _is_trace_equal(sch_1: Schedule, sch_2: Schedule) -> bool: - trace_1 = Trace(sch_1.trace.insts, {}) - trace_2 = Trace(sch_2.trace.insts, {}) +def _is_trace_equal(sch_1: Schedule, sch_2: Schedule, remove_decisions=True) -> 
bool: + if remove_decisions: + trace_1 = Trace(sch_1.trace.insts, {}) + trace_2 = Trace(sch_2.trace.insts, {}) + else: + trace_1 = sch_1.trace + trace_2 = sch_2.trace return str(trace_1) == str(trace_2) @@ -72,29 +76,35 @@ def _schedule_matmul(sch: Schedule): sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3) -def test_meta_schedule_replay_trace(): +@pytest.mark.parametrize("TestClass", [ReplayFunc, ReplayTrace]) +def test_meta_schedule_replay_func(TestClass: SearchStrategy): num_trials_per_iter = 7 num_trials_total = 20 - (example_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) - replay = ReplayTrace(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) - tune_context = TuneContext(mod=Matmul) - replay.initialize_with_tune_context(tune_context) - - num_trials_each_round: List[int] = [] - replay.pre_tuning([example_sch]) - while True: - candidates = replay.generate_measure_candidates() - if candidates is None: - break - num_trials_each_round.append(len(candidates)) + strategy = TestClass(num_trials_per_iter=num_trials_per_iter, num_trials_total=num_trials_total) + tune_context = TuneContext(mod=Matmul, space_generator=ScheduleFn(sch_fn=_schedule_matmul)) + tune_context.space_generator.initialize_with_tune_context(tune_context) + spaces = tune_context.space_generator.generate_design_space(tune_context.mod) + + strategy.initialize_with_tune_context(tune_context) + strategy.pre_tuning(spaces) + (correct_sch,) = ScheduleFn(sch_fn=_schedule_matmul).generate_design_space(Matmul) + num_trials_each_iter: List[int] = [] + candidates = strategy.generate_measure_candidates() + while candidates is not None: + num_trials_each_iter.append(len(candidates)) runner_results: List[RunnerResult] = [] for candidate in candidates: - assert _is_trace_equal(candidate.sch, example_sch) - runner_results.append(RunnerResult(run_secs=[0.5, 0.4, 0.3], error_msg=None)) - replay.notify_runner_results(runner_results) - 
replay.post_tuning() - assert num_trials_each_round == [7, 7, 6] + assert _is_trace_equal( + candidate.sch, + correct_sch, + remove_decisions=(type(strategy) == ReplayTrace), + ) + runner_results.append(RunnerResult(run_secs=[0.11, 0.41, 0.54], error_msg=None)) + strategy.notify_runner_results(runner_results) + candidates = strategy.generate_measure_candidates() + strategy.post_tuning() + assert num_trials_each_iter == [7, 7, 6] if __name__ == "__main__": diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 49a3f6309183..3eb050db3baa 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -23,6 +23,9 @@ import pytest import tvm +from tvm._ffi.base import TVMError +from tvm.ir.module import IRModule +from tvm.meta_schedule.space_generator.space_generator import PySpaceGenerator from tvm.script import tir as T from tvm.tir.schedule import Schedule from tvm.meta_schedule.space_generator import ScheduleFn, PySpaceGenerator, SpaceGeneratorUnion diff --git a/tests/python/unittest/test_meta_schedule_task_extraction.py b/tests/python/unittest/test_meta_schedule_task_extraction.py new file mode 100644 index 000000000000..8d1eca51432e --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_task_extraction.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import sys +from typing import Tuple + +import pytest + +import tvm +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing import MODEL_TYPE, MODEL_TYPES, get_torch_model + + +@pytest.mark.skip("Skip because it runs too slowly as a unittest") +@pytest.mark.parametrize( + "model_name", + [ + # Image classification + "resnet50", + "alexnet", + "vgg16", + "squeezenet1_0", + "densenet121", + "densenet161", + "densenet169", + "densenet201", + "inception_v3", + "googlenet", + "shufflenet_v2_x1_0", + "mobilenet_v2", + "mobilenet_v3_large", + "mobilenet_v3_small", + "resnext50_32x4d", + "wide_resnet50_2", + "mnasnet1_0", + # Segmentation + "fcn_resnet50", + "fcn_resnet101", + "deeplabv3_resnet50", + "deeplabv3_resnet101", + "deeplabv3_mobilenet_v3_large", + "lraspp_mobilenet_v3_large", + # Object detection + "fasterrcnn_resnet50_fpn", + "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", + "maskrcnn_resnet50_fpn", + # video classification + "r3d_18", + "mc3_18", + "r2plus1d_18", + ], +) +@pytest.mark.parametrize("batch_size", [1, 8, 16]) +@pytest.mark.parametrize("target", ["llvm", "cuda"]) +def test_meta_schedule_extract_from_torch_model(model_name: str, batch_size: int, target: str): + if model_name == "inception_v3" and batch_size == 1: + pytest.skip("inception_v3 does not handle batch_size of 1") + + input_shape: Tuple[int, ...] 
+ if MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION: + input_shape = (batch_size, 3, 299, 299) + elif MODEL_TYPES[model_name] == MODEL_TYPE.SEGMENTATION: + input_shape = (batch_size, 3, 299, 299) + elif MODEL_TYPES[model_name] == MODEL_TYPE.OBJECT_DETECTION: + input_shape = (1, 3, 300, 300) + elif MODEL_TYPES[model_name] == MODEL_TYPE.VIDEO_CLASSIFICATION: + input_shape = (batch_size, 3, 3, 299, 299) + else: + raise ValueError("Unsupported model: " + model_name) + + output_shape: Tuple[int, int] = (batch_size, 1000) + mod, params = get_torch_model( + model_name=model_name, + input_shape=input_shape, + output_shape=output_shape, + dtype="float32", + ) + target = tvm.target.Target(target) + ms.integration.extract_task(mod, params=params, target=target) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_task_scheduler.py b/tests/python/unittest/test_meta_schedule_task_scheduler.py index f1336b90c3fe..453c2d18f6dc 100644 --- a/tests/python/unittest/test_meta_schedule_task_scheduler.py +++ b/tests/python/unittest/test_meta_schedule_task_scheduler.py @@ -24,9 +24,10 @@ import pytest import tvm +from tvm._ffi.base import TVMError from tvm.script import tir as T from tvm.ir import IRModule -from tvm.tir import Schedule +from tvm.tir import Schedule, schedule from tvm.meta_schedule import TuneContext from tvm.meta_schedule.space_generator import ScheduleFn from tvm.meta_schedule.search_strategy import ReplayTrace diff --git a/tests/python/unittest/test_tir_schedule_read_write_at.py b/tests/python/unittest/test_tir_schedule_read_write_at.py new file mode 100644 index 000000000000..79a7aad10f25 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_read_write_at.py @@ -0,0 +1,221 @@ +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest + +import tvm +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + + +# fmt: off +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable + +@T.prim_func +def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=undefined-loop-variable + A = T.match_buffer(a, [2048, 2048], "float32") + B = T.match_buffer(b, [2048, 2048], "float32") + C = T.match_buffer(c, [2048, 2048], "float32") + for by in T.thread_binding(0, 32, thread = "blockIdx.y"): + for bx in T.thread_binding(0, 32, thread = "blockIdx.x"): + for vy in T.thread_binding(0, 2, thread = "vthread.y"): + for vx in T.thread_binding(0, 2, thread = "vthread.x"): + for ty in T.thread_binding(0, 8, thread = "threadIdx.y"): + for tx in T.thread_binding(0, 8, thread = "threadIdx.x"): + for k0 in T.serial(0, 256): + for k1 in T.unroll(0, 8): + for _, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A[vi, vk], B[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * 
B[vk, vj] + + +@T.prim_func +def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A_shared[vi, vk], B[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B[vk, vj] + + +@T.prim_func +def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in 
T.thread_binding(0, 2, thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + with T.block("B_shared"): + v0 = T.axis.S(256, k0) + v1 = T.axis.S(32, bx) + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(8, 64): + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] + +@T.prim_func +def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [2048, 2048], dtype="float32") + B = T.match_buffer(b, [2048, 2048], dtype="float32") + C = T.match_buffer(c, [2048, 2048], dtype="float32") + A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + C_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared") + for by in T.thread_binding(0, 32, thread="blockIdx.y"): + for bx in T.thread_binding(0, 32, thread="blockIdx.x"): + for vy in T.thread_binding(0, 2, 
thread="vthread.y"): + for vx in T.thread_binding(0, 2, thread="vthread.x"): + for ty in T.thread_binding(0, 8, thread="threadIdx.y"): + for tx in T.thread_binding(0, 8, thread="threadIdx.x"): + for k0 in T.serial(0, 256): + with T.block("A_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(256, k0) + T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 8): + A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1] + with T.block("B_shared"): + v0 = T.axis.S(256, k0) + v1 = T.axis.S(32, bx) + T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(8, 64): + B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1] + for k1 in T.unroll(0, 8): + for v_, i, j in T.grid(1, 4, 4): + with T.block("C"): + vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i) + vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j) + vk = T.axis.R(2048, k0 * 8 + k1) + T.reads([C_shared[vi, vj], A_shared[vi, vk], B_shared[vk, vj]]) + T.writes([C_shared[vi, vj]]) + with T.init(): + C_shared[vi, vj] = T.float32(0) + C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj] + with T.block("C_shared"): + v0 = T.axis.S(32, by) + v1 = T.axis.S(32, bx) + T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) + T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]]) + T.block_attr({"auto_copy":1}) + for ax0, ax1 in T.grid(64, 64): + C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + ax1] + + +# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable +# fmt: on + + +def test_read_at_global_to_shared_a(): + sch = tir.Schedule(cuda_matmul, debug_mask="all") + block = 
sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.read_at(k0, block, 1, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a) + verify_trace_roundtrip(sch, cuda_matmul) + + +def test_read_at_global_to_shared_ab(): + sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.read_at(k0, block, 2, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab) + verify_trace_roundtrip(sch, cuda_matmul_read_at_a) + + +def test_read_at_local_to_shared_c(): + sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all") + block = sch.get_block("C") + # pylint: disable=invalid-name + _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = sch.get_loops(block) + # pylint: enable=invalid-name + sch.write_at(tx, block, 0, "shared") + tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c) + verify_trace_roundtrip(sch, cuda_matmul_read_at_ab) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 440d0ab67a50..1596d08a1fb4 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -142,5 +142,22 @@ def test_tir_schedule_remove_rv(): sch.get(block_rv) +def test_get_child_blocks(): + s = tir.Schedule(matmul, debug_mask="all") + init = s.get_block("init") + update = s.get_block("update") + # loop + blocks = s.get_child_blocks(s.get_loops(init)[0]) + assert len(blocks) == 2 + assert s.get(init) == s.get(blocks[0]) + assert s.get(update) == s.get(blocks[1]) + # block + root = s.get_block("root") + blocks = 
s.get_child_blocks(root) + assert len(blocks) == 2 + assert s.get(init) == s.get(blocks[0]) + assert s.get(update) == s.get(blocks[1]) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))