diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index f07d8e136644..307ec309c009 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -237,6 +237,7 @@ class PyDatabaseNode : public DatabaseNode { // PackedFuncs are all not visited, because the reflection system doesn't take care of them, // so it cannot be accessible on the python side. If there is such need from the future, // we can then add corresponding accessor methods to help access on python. + // // `f_has_workload` is not visited // `f_commit_workload` is not visited // `f_commit_tuning_record` is not visited diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 7a7599b0a4f8..ff3a14c076e4 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -53,7 +53,7 @@ class TuneContextNode : public runtime::Object { /*! \brief The probability of using certain mutator. */ Map mutator_probs; /*! \brief The name of the tuning task. */ - Optional task_name; + String task_name; /*! \brief The random state. */ support::LinearCongruentialEngine::TRandState rand_state; /*! \brief The number of threads to be used. */ diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 89871f0d6352..49555e8e37f1 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -500,14 +500,14 @@ class ScheduleNode : public runtime::Object { /******** Schedule: Annotation ********/ /*! * \brief Annotate a loop with a key value pair - * \param loop_rv The loop to be annotated + * \param loop The loop to be annotated * \param ann_key The annotation key * \param ann_val The annotation value, a string or a ExprRV */ virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0; /*! * \brief Annotate a block with a key value pair - * \param block_rv The block to be annotated + * \param loop The block to be annotated * \param ann_key The annotation key * \param ann_val The annotation value, a string or a ExprRV */ @@ -515,13 +515,13 @@ class ScheduleNode : public runtime::Object { const ObjectRef& ann_val) = 0; /*! * \brief Unannotate a loop's annotation with key ann_key - * \param loop_rv The loop to be unannotated + * \param loop The loop to be unannotated * \param ann_key The annotation key */ virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0; /*! * \brief Unannotate a block's annotation with key ann_key - * \param block_rv The block to be unannotated + * \param loop The block to be unannotated * \param ann_key The annotation key */ virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 7b07146f446c..4074f5203857 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1442,6 +1442,93 @@ constexpr const char* nested_software_pipeline_stage = "nested_software_pipeline */ constexpr const char* nested_software_pipeline_order = "nested_software_pipeline_order"; +/*! + * \brief Mark that the block need to add predicate for block var bounds during lowering + */ +constexpr const char* require_block_var_bound_predicate = "require_bound_predicate"; + +/*! + * \brief Mark that the loop should be further skip and bound to environment threads to enable + * cooperative fetching. + */ +constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch"; + +/*! + * \brief Mark that the block should be further rewritten using tensorization. + */ +constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensorize"; + +/*! \brief Mark that tensor core is enabled in the PrimExpr */ +constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled"; + +/*! \brief The allowed range of thread extent in thread bindings */ +constexpr const char* meta_schedule_thread_extent_low_inclusive = + "meta_schedule.thread_extent_low_inclusive"; + +/*! \brief The allowed range of thread extent in thread bindings */ +constexpr const char* meta_schedule_thread_extent_high_inclusive = + "meta_schedule.thread_extent_high_inclusive"; + +/*! + * \brief Mark a block as generated by cache_read or cache_write block. + * 0 means cache_read; 1 means cache_write. + * \sa meta_schedule_cache_type_read + * \sa meta_schedule_cache_type_write + */ +constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type"; + +/*! \sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_read = 0; + +/*! \sa meta_schedule_cache_type */ +constexpr const int meta_schedule_cache_type_write = 1; + +/*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */ +constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure"; + +/*! \brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */ +constexpr const char* meta_schedule_random_compute_producer = + "meta_schedule.random_compute_producer"; + +/*! \brief Mark auto-parallel setting on the block. */ +constexpr const char* meta_schedule_parallel = "meta_schedule.parallel"; + +/*! \brief Mark auto-vectorize setting on the block. */ +constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize"; + +/*! \brief Mark auto-unroll setting on the block. */ +constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit"; + +/*! \brief Mark auto-unroll setting on the block. */ +constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit"; + +/*! \brief Pragma: auto-unroll, max_step */ +constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step"; + +/*! \brief Pragma: unroll explicit */ +constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit"; + +/*! \brief Mark the scope of the software pipeline */ +constexpr const char* software_pipeline_scope = "software_pipeline_scope"; + +/*! \brief Mark the stage of a statement in the software pipeline */ +constexpr const char* software_pipeline_stage = "software_pipeline_stage"; + +/*! \brief Mark the order of a statement in the software pipeline */ +constexpr const char* software_pipeline_order = "software_pipeline_order"; + +/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify + * the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the + * prologue, the body, and the epilogue of the software pipeline. + */ +constexpr const char* nested_software_pipeline_stage = "nested_software_pipeline_stage"; + +/*! \brief Mark the stage of the result of the software pipeline lowering. This is used to specify + * the behavior of nested software pipelines. Should be a 3-tuple consisting of the stage of the + * prologue, the body, and the epilogue of the software pipeline. + */ +constexpr const char* nested_software_pipeline_order = "nested_software_pipeline_order"; + /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 4df54f0208d3..6b8edf29bf2c 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -383,6 +383,20 @@ TVM_DLL Pass LowerInitBlock(); */ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); +/*! + * \brief Narrow the extents of some loops by checking whether some constraints in the block iter + * bound predicates can be directly applied on the loops. + * \return The pass. + */ +TVM_DLL Pass ApplyBlockBoundPredicate(); + +/*! + * \brief Narrow the extents of some loops by checking whether some constraints in the block iter + * bound predicates can be directly applied on the loops. + * \return The pass. + */ +TVM_DLL Pass ApplyBlockBoundPredicate(); + /*! * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the * corresponding iter_values in BlockRealize, for opaque blocks by removing all diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index f1156998bdac..0e9c4abebbe1 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -543,7 +543,8 @@ def print_best(self, log_file, print_mode="schedule"): code: str The best schedule code in python API or CUDA source code """ - inp, _ = load_best_record(log_file, self.workload_key) + inp, res = load_best_record(log_file, self.workload_key) + print("Best codes (ms):", [float(c) * 1000.0 for c in res.costs]) if inp is None: raise RuntimeError( "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file) diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 885eb0d1d0f8..75702b0a21af 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -194,7 +194,10 @@ def workload_key_to_tensors(workload_key): assert callable(value) args = deserialize_args(workload[1:]) - return value(*args) + result = value(*args) + if isinstance(result, tuple): + result = list(result) + return result def serialize_workload_registry_entry(workload_key): diff --git a/python/tvm/meta_schedule/builder/local_builder.py b/python/tvm/meta_schedule/builder/local_builder.py index da7bb515f112..ca38424957db 100644 --- a/python/tvm/meta_schedule/builder/local_builder.py +++ b/python/tvm/meta_schedule/builder/local_builder.py @@ -22,13 +22,28 @@ from tvm._ffi import register_func from tvm.ir import IRModule -from tvm.runtime import Module, NDArray, load_param_dict, save_param_dict +from tvm.runtime import NDArray +from tvm.runtime import Module, load_param_dict, save_param_dict from tvm.target import Target from ...contrib.popen_pool import MapResult, PopenPoolExecutor, StatusKind from ..utils import cpu_count, get_global_func_with_default_on_worker from .builder import BuilderInput, BuilderResult, PyBuilder +logger = logging.getLogger(__name__) + + +def _serialize_params(params: Optional[Dict[str, NDArray]]) -> Optional[bytearray]: + if params is None: + return None + return save_param_dict(params) + + +def _deserialize_params(params: Optional[bytearray]) -> Optional[Dict[str, NDArray]]: + if params is None: + return None + return load_param_dict(params) + logger = logging.getLogger(__name__) # pylint: disable=invalid-name @@ -127,7 +142,6 @@ def __init__( The initializer to be used for the worker processes. """ super().__init__() - if max_workers is None: max_workers = cpu_count(logical=True) logger.info("LocalBuilder: max_workers = %d", max_workers) diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index f794b11471d9..4fdd80b1769b 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -15,17 +15,19 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule CostModel.""" -import ctypes + from typing import List +import ctypes + +import numpy as np -import numpy as np # type: ignore from tvm._ffi import register_object from tvm.runtime import Object from .. import _ffi_api from ..runner import RunnerResult -from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate from ..utils import _get_hex_address, check_override diff --git a/python/tvm/meta_schedule/cost_model/metric.py b/python/tvm/meta_schedule/cost_model/metric.py index efd8dc68ac0d..7eb6da6f07d9 100644 --- a/python/tvm/meta_schedule/cost_model/metric.py +++ b/python/tvm/meta_schedule/cost_model/metric.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. """Cost model metrics for meta schedule""" -import numpy as np # type: ignore +from typing import List +import numpy as np -def max_curve(trial_scores: np.ndarray) -> np.ndarray: +def max_curve(trial_scores: np.ndarray) -> List[float]: """f(n) = max([s[i] fo i < n]) Parameters @@ -28,8 +29,8 @@ def max_curve(trial_scores: np.ndarray) -> np.ndarray: Returns ------- - curve : np.ndarray - A vector, the max-curve function values + curve : List[float] + function values """ ret = np.empty(len(trial_scores)) keep = -1e9 diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py index 8808476aba15..1bb5fc237ae5 100644 --- a/python/tvm/meta_schedule/cost_model/random_model.py +++ b/python/tvm/meta_schedule/cost_model/random_model.py @@ -17,14 +17,14 @@ """ Random cost model """ -from typing import List, Optional, Tuple, Union +from typing import List, Union, Tuple, Optional -import numpy as np # type: ignore +import numpy as np -from ..cost_model import PyCostModel from ..runner import RunnerResult -from ..search_strategy import MeasureCandidate from ..tune_context import TuneContext +from ..search_strategy import MeasureCandidate +from ..cost_model import PyCostModel class RandomModel(PyCostModel): @@ -70,7 +70,7 @@ def load(self, path: str) -> None: path : str The file path. """ - self.random_state = tuple(np.load(path, allow_pickle=True)) # type: ignore + self.random_state = tuple(np.load(path, allow_pickle=True)) def save(self, path: str) -> None: """Save the cost model to given file location. @@ -116,7 +116,7 @@ def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> n The predicted running results. """ np.random.set_state(self.random_state) - # TODO(@zxybazh): Use numpy's RandState object: + # todo(@zxybazh): Use numpy's RandState object: # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState result = np.random.rand(len(candidates)) * self.max_range self.random_state = np.random.get_state() diff --git a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py index d805648bfbfd..d52eda3daac1 100644 --- a/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/random_feature_extractor.py @@ -17,7 +17,7 @@ """Random Feature Extractor.""" from typing import List, Union, Tuple -import numpy as np # type: ignore +import numpy as np from tvm.runtime.ndarray import NDArray, array from ..tune_context import TuneContext diff --git a/python/tvm/meta_schedule/runner/local_runner.py b/python/tvm/meta_schedule/runner/local_runner.py index b1a9c678c6fc..6af403905cb4 100644 --- a/python/tvm/meta_schedule/runner/local_runner.py +++ b/python/tvm/meta_schedule/runner/local_runner.py @@ -33,7 +33,7 @@ run_evaluator_common, ) -logger = logging.getLogger(__name__) # pylint: disable=invalid-name +logger = logging.getLogger(__name__) class LocalRunnerFuture(RunnerFuture): diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index 80f372a448f5..a9b2d560314a 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -32,5 +32,5 @@ class PostOrderApply(SpaceGenerator): def __init__(self): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.SpaceGeneratorPostOrderApply, # type: ignore # pylint: disable=no-member + _ffi_api.SpaceGeneratorPostOrderApply, # pylint: disable=no-member ) diff --git a/python/tvm/meta_schedule/utils.py b/python/tvm/meta_schedule/utils.py index b6fe34839264..4a63134417e1 100644 --- a/python/tvm/meta_schedule/utils.py +++ b/python/tvm/meta_schedule/utils.py @@ -21,7 +21,7 @@ import shutil from typing import Any, Callable, List, Optional, Union -import psutil # type: ignore +import psutil import tvm from tvm._ffi import get_global_func, register_func from tvm.error import TVMError @@ -66,29 +66,47 @@ def _process_error_message(error_msg: str) -> str: def cpu_count(logical: bool = True) -> int: """Return the number of logical or physical CPUs in the system - Parameters ---------- logical : bool = True If True, return the number of logical CPUs, otherwise return the number of physical CPUs - Returns ------- cpu_count : int The number of logical or physical CPUs in the system - Note ---- The meta schedule search infra intentionally does not adopt the following convention in TVM: - C++ API `tvm::runtime::threading::MaxConcurrency()` - Environment variable `TVM_NUM_THREADS` or - Environment variable `OMP_NUM_THREADS` - This is because these variables are dedicated to controlling the runtime behavior of generated kernels, instead of the host-side search. Setting these variables may interfere the host-side search with profiling of generated kernels when measuring locally. """ + return psutil.cpu_count(logical=logical) or 1 + + +@register_func("meta_schedule._process_error_message") +def _process_error_message(error_msg: str) -> str: + error_msg_lines = str(error_msg).splitlines() + if len(error_msg_lines) >= 50: + return "\n".join(error_msg_lines[:25] + ["..."] + error_msg_lines[-25:]) + return error_msg + + +def cpu_count(logical: bool = True) -> int: + """Return the number of logical or physical CPUs in the system + Parameters + ---------- + logical : bool = True + If True, return the number of logical CPUs, otherwise return the number of physical CPUs + Returns + ------- + cpu_count : int + The number of logical or physical CPUs in the system + """ return _cpu_count_impl(logical) @@ -97,17 +115,14 @@ def get_global_func_with_default_on_worker( default: Callable, ) -> Callable: """Get the registered global function on the worker process. - Parameters ---------- name : Union[None, str, Callable] If given a string, retrieve the function in TVM's global registry; If given a python function, return it as it is; Otherwise, return `default`. - default : Callable The function to be returned if `name` is None. - Returns ------- result : Callable @@ -135,7 +150,6 @@ def get_global_func_on_rpc_session( extra_error_msg: Optional[str] = None, ) -> PackedFunc: """Get a PackedFunc from the global registry from an RPCSession. - Parameters ---------- session : RPCSession @@ -144,7 +158,6 @@ def get_global_func_on_rpc_session( The name of the PackedFunc extra_error_msg : Optional[str] Extra information to provide in the error message - Returns ------- result : PackedFunc @@ -168,12 +181,10 @@ def remove_build_dir(artifact_path: str) -> None: def _json_de_tvm(obj: Any) -> Any: """Unpack a TVM nested container to a JSON object in python. - Parameters ---------- obj : Any The TVM nested container to be unpacked. - Returns ------- result : Any @@ -221,12 +232,10 @@ def batch_json_str2obj(json_strs: List[str]) -> List[Any]: def structural_hash(mod: IRModule) -> str: """Get the structural hash of a module. - Parameters ---------- mod : IRModule The module to be hashed. - Returns ------- result : str @@ -240,11 +249,24 @@ def structural_hash(mod: IRModule) -> str: return str(shash) +def _get_hex_address(handle: ctypes.c_void_p) -> str: + """Get the hexadecimal address of a handle. + Parameters + ---------- + handle : ctypes.c_void_p + The handle to be converted. + Returns + ------- + result : str + The hexadecimal address of the handle. + """ + return hex(ctypes.cast(handle, ctypes.c_void_p).value) + + def check_override( derived_class: Any, base_class: Any, required: bool = True, func_name: str = None ) -> Callable: """Check if the derived class has overridden the base class's method. - Parameters ---------- derived_class : Any @@ -256,7 +278,6 @@ def check_override( func_name : str Name of the method. Default value None, which would be set to substring of the given function, e.g. `f_generate`->`generate`. - Returns ------- func : Callable @@ -278,17 +299,3 @@ def inner(func: Callable): return func return inner - - -def _get_hex_address(handle: ctypes.c_void_p) -> str: - """Get the hexadecimal address of a handle. - Parameters - ---------- - handle : ctypes.c_void_p - The handle to be converted. - Returns - ------- - result : str - The hexadecimal address of the handle. - """ - return hex(ctypes.cast(handle, ctypes.c_void_p).value) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 5cfd3a16c3bc..e9fc10186c87 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -284,13 +284,17 @@ def _module_export(module, file_name): # fcompile, addons, kwargs? @register_func("tvm.relay.build") +def _build_module_no_factory_impl(mod, target, target_host, params, mod_name): + target, target_host = Target.check_and_update_host_consist(target, target_host) + return build(mod, target, params=params, mod_name=mod_name).module + + def _build_module_no_factory(mod, target=None, target_host=None, params=None, mod_name="default"): """A wrapper around build which discards the Python GraphFactoryRuntime. This wrapper is suitable to be used from other programming languages as the runtime::Module can be freely passed between language boundaries. """ - target, target_host = Target.check_and_update_host_consist(target, target_host) - return build(mod, target, params=params, mod_name=mod_name).module + return _build_module_no_factory_impl(mod, target, target_host, params, mod_name) def _reconstruct_from_deprecated_options(deprecated_params_target): diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 42bd52930b1a..475c38b71cb5 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -19,16 +19,16 @@ from typing import Callable, List, Mapping, Union import inspect -import tvm._ffi -import tvm.runtime -from tvm.runtime import Object +from tvm._ffi import get_global_func, register_object from tvm.ir import BaseFunc -from .buffer import Buffer -from .expr import Var, PrimExpr +from tvm.runtime import Object, convert + from . import _ffi_api +from .buffer import Buffer +from .expr import PrimExpr, Var -@tvm._ffi.register_object("tir.PrimFunc") +@register_object("tir.PrimFunc") class PrimFunc(BaseFunc): """A function declaration expression. @@ -57,7 +57,7 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa param_list = [] buffer_map = {} if buffer_map is None else buffer_map for x in params: - x = tvm.runtime.convert(x) if not isinstance(x, Object) else x + x = convert(x) if not isinstance(x, Object) else x if isinstance(x, Buffer): var = Var(x.name, dtype="handle") param_list.append(var) @@ -68,7 +68,13 @@ def __init__(self, params, body, ret_type=None, buffer_map=None, attrs=None, spa raise TypeError("params can only contain Var or Buffer") self.__init_handle_by_constructor__( - _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs, span # type: ignore + _ffi_api.PrimFunc, # type: ignore # pylint: disable=no-member + param_list, + body, + ret_type, + buffer_map, + attrs, + span, ) def with_body(self, new_body, span=None): @@ -142,7 +148,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: func : PrimFunc The new function with parameter specialized """ - return _ffi_api.Specialize(self, param_map) # type: ignore + return _ffi_api.Specialize(self, param_map) # type: ignore # pylint: disable=no-member def script(self, tir_prefix: str = "T", show_meta: bool = False) -> str: """Print IRModule into TVMScript diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 51cf67f92542..596d70c2d342 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -2014,7 +2014,6 @@ def after_tensorize( ########## Schedule: Annotation ########## - @type_checked def annotate( self, block_or_loop: Union[BlockRV, LoopRV], @@ -2031,45 +2030,6 @@ def annotate( The annotation key ann_val : Union[str, int, float, ExprRV, List[Union[str, int, float, ExprRV]]] The annotation value - - Examples - -------- - - Before annotate, in TensorIR, the IR is: - - .. code-block:: python - - @T.prim_func - def before_annotate(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) - for i, j in T.grid(128, 128): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] * 2.0 - - Create the schedule and do annotate: - - .. code-block:: python - - sch = tir.Schedule(before_annotate) - sch.annotate(sch.get_block("B"), "ann_key", "ann_value") - print(sch.mod["main"].script()) - - After applying annotate, the IR becomes: - - .. code-block:: python - - @T.prim_func - def after_annotate(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) - for i, j in T.grid(128, 128): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"ann_key", "ann_value"}) - B[vi, vj] = A[vi, vj] * 2.0 - """ if isinstance(ann_val, str): ann_val = String(ann_val) @@ -2077,11 +2037,10 @@ def after_annotate(a: T.handle, b: T.handle) -> None: ann_val = IntImm("int32", ann_val) elif isinstance(ann_val, float): ann_val = FloatImm("float32", ann_val) - _ffi_api.ScheduleAnnotate( # type: ignore # pylint: disable=no-member + _ffi_api.ScheduleAnnotate( # pylint: disable=no-member self, block_or_loop, ann_key, ann_val ) - @type_checked def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> None: """Unannotate a block/loop's annotation with key ann_key @@ -2091,48 +2050,83 @@ def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> Non The block/loop to be unannotated ann_key : str The annotation key + """ + _ffi_api.ScheduleUnannotate(self, block_or_loop, ann_key) # pylint: disable=no-member + + ########## Schedule: Layout transformation ########## + + def transform_layout( + self, + block: BlockRV, + buffer_index: int, + is_write_index: bool, + index_map: Union[IndexMap, Callable], + ) -> None: + """Apply a transformation represented by IndexMap to buffer + + Parameters + ---------- + block_rv : BlockRV + The block that accesses the target buffer + buffer_index: int + The index of the buffer in block's read or write region + is_write_index : bool + Whether the buffer_index is the index of the block's write region + index_map : Union[IndexMap, Callable] + The transformation to apply Examples -------- - Before unannotate, in TensorIR, the IR is: + Before transform_layout, in TensorIR, the IR is: .. code-block:: python @T.prim_func - def before_unannotate(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) + def before_transform_layout(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((128, 128), "float32") + C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - T.block_attr({"ann_key", "ann_value"}) B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 1.0 - Create the schedule and do annotate: + Create the schedule and do transform_layout: .. code-block:: python - sch = tir.Schedule(before_unannotate) - sch.unannotate(sch.get_block("B"), "ann_key") + sch = tir.Schedule(before_storage_align) + sch.transform_layout(sch.get_block("B"), buffer_index=0, is_write_index=True, + index_map=lambda m, n: (m // 16, n // 16, m % 16, n % 16)) print(sch.mod["main"].script()) - After applying unannotate, the IR becomes: + After applying transform_layout, the IR becomes: .. code-block:: python @T.prim_func - def after_unannotate(a: T.handle, b: T.handle) -> None: - A = T.match_buffer(a, (128, 128)) - B = T.match_buffer(b, (128, 128)) + def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128), "float32") + B = T.alloc_buffer((8, 8, 16, 16), "float32") + C = T.match_buffer(c, (128, 128), "float32") for i, j in T.grid(128, 128): with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] * 2.0 - + B[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj] * 2.0 + for i, j in T.grid(128, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi // 16, vj // 16, vi % 16, vj % 16] + 1.0 """ - _ffi_api.ScheduleUnannotate( # type: ignore # pylint: disable=no-member - self, block_or_loop, ann_key + if callable(index_map): + index_map = IndexMap.from_func(index_map) + _ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member + self, block, buffer_index, is_write_index, index_map ) ########## Schedule: Layout transformation ########## diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 0c0ea9cdb3a2..49745cd1a91d 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -636,6 +636,18 @@ def PlanAndUpdateBufferAllocationLocation(): return _ffi_api.PlanAndUpdateBufferAllocationLocation() # type: ignore +def ApplyBlockBoundPredicate(): + """Narrow the extents of some loops by checking whether some constraints in the block iter + bound predicates can be directly applied on the loops. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ApplyBlockBoundPredicate() # type: ignore + + def ConvertBlocksToOpaque(): """Substitute all the block vars with the PrimExprs they are bound to, indicated by the corresponding iter_values in BlockRealize, and then convert the blocks into diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index a4de6592ca13..02e940ea79e3 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -530,12 +530,22 @@ class IterMapRewriter : public ExprMutator { if (predicate_induced_max.defined()) { iter_max = min(predicate_induced_max.value(), iter_max); } - if (!is_zero(iter_min)) { - // structured form's offset should be updated - flattened_map_.erase(structured_form); - structured_form.CopyOnWrite()->base = -iter_min; - mark.CopyOnWrite()->source = structured_form; - flattened_map_[structured_form] = flattened_form; + if (analyzer_->CanProve(iter_min <= iter_max)) { + if (!is_zero(iter_min)) { + // structured form's offset should be updated + flattened_map_.erase(structured_form); + structured_form.CopyOnWrite()->base = -iter_min; + mark.CopyOnWrite()->source = structured_form; + flattened_map_[structured_form] = flattened_form; + } + mark.CopyOnWrite()->extent = iter_max - iter_min; + sum_fuse_map_[flattened_form] = {mark, iter_min}; + // we need to note down the flattened form of constrained iterators + // to check the validity of constraints, see also CheckConstraints() + constrained_iters_flattened_.push_back(flattened_form); + expr.CopyOnWrite()->args = Array({split}); + expr.CopyOnWrite()->base = base + iter_min; + return expr; } mark.CopyOnWrite()->extent = iter_max - iter_min; sum_fuse_map_[flattened_form] = {mark, iter_min}; @@ -611,7 +621,7 @@ class IterMapRewriter : public ExprMutator { } } } - if (!base_scale) { + if (!base_scale || base_scale.value()->value < 0) { diag_ctx_.Emit(Diagnostic::Error(expr->span) << "Fuse iters failed, can not find a valid base scale"); return NullOpt; @@ -890,7 +900,20 @@ bool MatchBoundConstraints(PrimExpr pred, Map& input_iters, iter = lhs_expr; } } - result.emplace_back(iter, lower_bound, upper_bound, 0); + // If it is a predicate for input iters + if (const auto* var_ptr = iter.as()) { + auto it = input_iters.find(GetRef(var_ptr)); + if (it == input_iters.end()) { + return false; + } + PrimExpr iter_min = (*it).second->min; + PrimExpr iter_max = (*it).second->min + (*it).second->extent; + if (lower_bound.defined()) iter_min = max(iter_min, lower_bound.value()); + if (upper_bound.defined()) iter_max = min(iter_max, upper_bound.value()); + input_iters.Set(GetRef(var_ptr), Range(iter_min, iter_max)); + } else { + result.emplace_back(iter, lower_bound, upper_bound, 0); + } if (is_finish) { break; } diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 1eac10d1ad82..8c9e2d8949e9 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include "tvm/tir/schedule/schedule.h" namespace tvm { namespace meta_schedule { diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index bd76ca794a9a..0f4aa582c3a8 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -27,17 +27,18 @@ #include #include #include +#include +#include #include #include #include #include #include #include -#include -#include #include -#include +#include +#include #include #include diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 92bd6bd4bf99..e47c33a9d22e 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -70,6 +70,26 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl */ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); +/*! + * \brief The information of a block scope, including the leaf blocks, + * as well as the loop types (spatial, reduction) for each loop in the scope. + */ +struct ScopeBlockLoopInfo { + /*! \brief A list of the leaf blocks, from left to right */ + std::vector realizes; + /*! \brief The loop vars bound to spatial block iters */ + std::unordered_set spatial_vars; + /*! \brief The loop vars bound to non-spatial block iters */ + std::unordered_set non_spatial_vars; +}; + +/*! + * \brief Inspect the scope of the given sref + * \param scope_block The root block of the scope + * \return The information of the scope + */ +ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block); + /******** Scope ********/ /*! * \brief Checks if scope the specified sref is in is a stage-pipeline and return it @@ -235,6 +255,15 @@ bool IsAffineBinding(const BlockRealize& realize, const Map& loop_va */ void CheckAffineBinding(const ScheduleState& self, Block block); +/*! + * \brief Check whether a block has a trivial binding, i.e. each block var is bound to a outer loop, + * from outer to inner. + * \param self The schedule state + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block has a trivial binding + */ +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref); + /*! * \brief Extracts the ranges of loop variables in a path of the sref tree * \param low_inclusive The lowest node in the path @@ -618,27 +647,17 @@ bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref); bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref); /*! - * \brief Checks if a producer block could be successfully computed at the specific loop. - * \param self The schedule state - * \param block_sref The block to be moved - * \param loop_sref The loop where the block to be moved to - * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 - * \return A boolean indicating whether the block could be successfully compute at the specific loop + * \brief Provided the access pattern to a buffer, suggest one of the possible layout + * transformation to minimize the locality of the access pattern. + * \param buffer The buffer to be transformed + * \param indices The access pattern to the buffer + * \param loops The loops above the buffer + * \param predicate The predicate of the access + * \param analyzer Arithmetic analyzer */ -bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, - bool preserve_unit_loops); - -/*! - * \brief Checks if a consumer block could be successfully computed at the specific loop. - * \param self The schedule state - * \param block_sref The block to be moved - * \param loop_sref The loop where the block to be moved to - * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1 - * \return A boolean indicating whether the block could be successfully reverse compute at the - * specific loop - */ -bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, - const StmtSRef& loop_sref, bool preserve_unit_loops); +Optional SuggestIndexMap(const Buffer& buffer, const Array& indices, + const Array& loops, const PrimExpr& predicate, + arith::Analyzer* analyzer); /*! * \brief Provided the access pattern to a buffer, suggest one of the possible layout diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index bdb4295e900b..4f642af8b95b 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -518,6 +518,22 @@ void CheckAffineBinding(const ScheduleState& self, Block block) { } } +bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + Array loops = GetLoops(block_sref); + Array binds = GetBlockRealize(self, block_sref)->iter_values; + if (loops.size() != binds.size()) { + return false; + } + for (int i = 0, n = loops.size(); i < n; ++i) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loops[i]); + if (binds[i].get() != loop->loop_var.get()) { + return false; + } + } + return true; +} + Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, const Optional& high_exclusive, const runtime::StorageScope& extra_relax_scope) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 3501e7cb723f..60be2efb5245 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -237,7 +237,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { if (it == this->symbol_table_.end()) { LOG(FATAL) << "IndexError: Cannot find corresponding BlockRV: " << block_rv; } - const ObjectRef& obj = (*it).second; + ObjectRef obj = (*it).second; const auto* sref = obj.as(); if (sref == nullptr) { LOG(FATAL) << "ValueError: BlockRV's corresponding type is invalid: " @@ -256,7 +256,7 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { if (it == this->symbol_table_.end()) { LOG(FATAL) << "IndexError: Cannot find corresponding LoopRV: " << loop_rv; } - const ObjectRef& obj = (*it).second; + ObjectRef obj = (*it).second; if (obj.same_as(inline_mark)) { return inline_mark; } diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 14d05a4a340c..71ee09ab6829 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -43,7 +43,7 @@ namespace tir { * * // Convertible to `InstructionKindNode::FInstructionApply` * static Array ApplyToSchedule( - * const tir::Schedule& sch, + * const Schedule& sch, * const Array& inputs, * const Array& attrs, * const Optional& decision); diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index b445b5a9ded8..4ad09ab3dfdf 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -22,6 +22,7 @@ #include #include +#include #include namespace tvm { @@ -440,7 +441,6 @@ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int * \param ann_key The annotation key */ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key); - /******** Schedule: Misc ********/ } // namespace tir diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index f5c1978a1b25..4ed40817132d 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -116,8 +116,7 @@ struct AnnotateTraits : public UnpackedInstTraits { return py.Str(); } - template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct UnpackedInstTraits; }; struct UnannotateTraits : public UnpackedInstTraits { @@ -148,8 +147,7 @@ struct UnannotateTraits : public UnpackedInstTraits { return py.Str(); } - template - friend struct ::tvm::tir::UnpackedInstTraits; + friend struct UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(AnnotateTraits); diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index b811afb23614..1ed5bdc03f51 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -25,6 +25,49 @@ using support::NDIntSet; /******** Error Classes ********/ +/*! + * \brief Represent the iteration domain to fully cover the required region of Intersect(dom, bound) + * The bound region may not get directly intersected with dom region, instead we try to generate + * extra predicates for non-trivial bound. The domain info class can also union with each other. + */ +struct BlockVarDomainInfo { + arith::IntSet dom{arith::IntSet::Nothing()}; // dom is ensured to be bounded + arith::IntSet bound{arith::IntSet::Nothing()}; + + /*! \brief Relaxed union operation */ + void Union(const BlockVarDomainInfo& other) { + // just relax (d0 ^ b0) v (d1 ^ b1) to (d0 v d1) ^ (b0 v b1) + dom = arith::Union({dom, other.dom}); + bound = arith::Union({bound, other.bound}); + } + + /*! \brief Simplify domain info */ + void Simplify(arith::Analyzer* analyzer) { + auto to_simplified = [analyzer](const arith::IntSet& set) { + PrimExpr min = set.HasLowerBound() ? analyzer->Simplify(set.min()) : set.min(); + PrimExpr max = set.HasUpperBound() ? analyzer->Simplify(set.max()) : set.max(); + return arith::IntSet::Interval(min, max); + }; + // if no dom specified, try use bound as dom + if (dom.IsNothing()) { + if (bound.HasLowerBound() && bound.HasUpperBound()) { + bound = to_simplified(bound); + std::swap(dom, bound); + } + return; + } + // simplify intsets + dom = to_simplified(dom); + bound = to_simplified(bound); + // if can proof the dom is within bound, remove bound + auto intersect = to_simplified(arith::Intersect({dom, bound})); + if (analyzer->CanProveEqual(dom.min(), intersect.min()) && + analyzer->CanProveEqual(dom.max(), intersect.max())) { + bound = arith::IntSet::Nothing(); + } + } +}; + /*! * \brief An error raised when not all required blocks are under the given loop. * \tparam is_consumer Indicates if all the required blocks are consumers or producers @@ -317,6 +360,8 @@ class ScopeReconstructor : private StmtMutator { Stmt rm_src_stmt_{nullptr}; /*! \brief The plan to remove the given block by replacing to this loop/block in the AST */ Stmt rm_tgt_stmt_{nullptr}; + /*! \brief Bound predicate for the given block to be moved */ + Optional predicate{NullOpt}; }; /*! @@ -547,9 +592,11 @@ void CalculateProvidedRequiredRegions( /******** Main Implementation ********/ template -void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref, - const StmtSRef& loop_sref, bool preserve_unit_loops, - arith::Analyzer* analyzer, bool check_only = false) { +std::function ComputeAtOrReverseComputeAtImpl(ScheduleState self, + const StmtSRef& block_sref, + const StmtSRef& loop_sref, + bool preserve_unit_loops, + arith::Analyzer* analyzer) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); // Step 1. Bunch of checks @@ -604,32 +651,35 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms), /*analyzer=*/analyzer, /*preserve_unit_loops=*/preserve_unit_loops); Block new_scope_root = Downcast(reconstructor(scope_root)); - - // Step 7. Do the actual replacement - if (check_only) { - return; - } - self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}}); - // Step 8. Update the cached flags - BlockInfo& block_info = self->block_info[block_sref]; - block_info.affine_binding = IsAffineBinding( - /*realize=*/reconstructor.new_block_realize_, - /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(block_sref->parent)), - /*analyzer=*/analyzer); + Optional bound_predicate = reconstructor.predicate; + return [=]() -> void { + // Step 7. Do the actual replacement + self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}}); + // Step 8. Update the cached flags + BlockInfo& block_info = self->block_info[block_sref]; + block_info.affine_binding = IsAffineBinding( + /*realize=*/reconstructor.new_block_realize_, + /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef(block_sref->parent)), + /*analyzer=*/analyzer); + // Step 9. Add bound predicate annotation for the block to be moved if needed + if (bound_predicate.defined()) { + Annotate(self, block_sref, attr::require_block_var_bound_predicate, bound_predicate.value()); + } + }; } void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops) { arith::Analyzer analyzer; ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer); + &analyzer)(); } void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref, bool preserve_unit_loops) { arith::Analyzer analyzer; ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer); + &analyzer)(); } bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref, @@ -637,7 +687,7 @@ bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const S arith::Analyzer analyzer; try { ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer, true); + &analyzer); } catch (const tvm::runtime::Error& e) { return false; } @@ -649,7 +699,7 @@ bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref, arith::Analyzer analyzer; try { ComputeAtOrReverseComputeAtImpl(self, block_sref, loop_sref, preserve_unit_loops, - &analyzer, true); + &analyzer); } catch (const tvm::runtime::Error& e) { return false; } diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 03ffb4fe159e..98106d51f4db 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -490,7 +490,6 @@ class LoopPropertyError : public ScheduleError { CheckGetSingleChildBlockRealizeOnSRefTree(self, self->stmt2ref.at(loop.get())); meet_reduction_loop = true; } - continue; } else if (meet_reduction_loop && !is_one(loop->extent)) { throw LoopPropertyError(self->mod, loop, kUnboundLoopUnderReductionLoop); } @@ -591,8 +590,8 @@ class BaseBlockCreator { } private: - virtual void CreateAdditionalIter() = 0; virtual void CreateNormalIters(int idx) = 0; + virtual void CreateAdditionalIter() = 0; virtual void CreateReductionUpdate() = 0; virtual void CreateReadWriteRegions() = 0; @@ -825,6 +824,13 @@ class WriteBackBlockCreator : public BaseBlockCreator { } } + void CreateAdditionalIter() final { + additional_iter_ = IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce); + iter_vars_.insert(iter_vars_.end(), additional_iter_); + iter_values_.insert(iter_values_.end(), rf_loop_->loop_var); + var_map_.Set(rf_additional_iter_->var, additional_iter_->var); + } + void CreateReductionUpdate() final { wb_lhs_ = Downcast(Substitute(combiner_lhs_, var_map_)); wb_rhs_ = diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 0e767825573f..9e1658a61768 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -86,6 +86,7 @@ struct PrimeTable { pow_tab.emplace_back(std::move(tab)); } } + /*! * \brief Factorize a number n, and return in a cryptic format * \param n The number to be factorized @@ -299,27 +300,17 @@ std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandS return SamplePerfectTile(rand_state, extent, n_splits); } CHECK_GE(n_splits, 2) << "ValueError: Cannot tile a loop into " << n_splits << " splits"; - std::vector innermost_candidates; - innermost_candidates.reserve(max_innermost_factor); - for (int32_t i = 1; i <= max_innermost_factor; ++i) { - if (extent % i == 0) { - innermost_candidates.push_back(i); + while (true) { + std::vector result = SamplePerfectTile(rand_state, extent, n_splits); + if (result.back() <= max_innermost_factor) { + return result; } } - // N.B. Theoretically sampling evenly breaks the uniform sampling of the global sampling space. - // We should do multiple factorization to weight the choices. However, it would lead to slower - // sampling speed. On the other hand, considering potential tricks we might do on the innermost - // loop, in which sampling uniformly does not help, let's leave it as it is for now, and maybe add - // more heuristics in the future - int32_t innermost = innermost_candidates[SampleInt(rand_state, 0, innermost_candidates.size())]; - std::vector result = SamplePerfectTile(rand_state, extent / innermost, n_splits - 1); - result.push_back(innermost); - return result; } std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // - const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, + const StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); const int64_t* extent = GetLoopIntExtent(loop); diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index eb43157d805a..2624afa476e0 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -339,6 +339,10 @@ class BlockInfoCollector : private StmtVisitor { /*dom_low_inclusive=*/parent_sref, /*dom_high_exclusive=*/lca, /*analyzer=*/&analyzer_); + for (size_t i = 0; i < consumed_region.size(); ++i) { + const arith::IntSet consumed_interset = arith::Intersect( + {consumed_region[i], arith::IntSet::FromMinExtent(0, buffer->shape[i])}); + } if (!ProducerCoversConsumer(buffer->shape, produced_region, consumed_region, &analyzer_)) { region_cover = false; @@ -898,7 +902,7 @@ class ChildReplacer : private StmtMutator { int seq_index_; }; -void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, +void ScheduleStateNode::Replace(const StmtSRef& _src_sref, const Stmt& tgt_stmt, const Map& _block_sref_reuse) { if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; diff --git a/tests/python/unittest/test_meta_schedule_byoc.py b/tests/python/unittest/test_meta_schedule_byoc.py new file mode 100644 index 000000000000..fe50350d5133 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_byoc.py @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test Meta Schedule Builder """ +# pylint: disable=missing-docstring + +import sys + +import pytest +import tvm +from tvm import relay +from tvm.meta_schedule.arg_info import TensorInfo +from tvm.meta_schedule.builder import BuilderInput, LocalBuilder +from tvm.meta_schedule.runner import EvaluatorConfig, LocalRunner, RunnerInput +from tvm.meta_schedule.testing import get_network +from tvm.meta_schedule.testing.byoc_trt import ( + build_relay, + build_relay_with_tensorrt, + run_with_graph_executor, +) +from tvm.relay import testing +from tvm.relay.op.contrib import tensorrt +from tvm.target import Target +from tvm.tir import FloatImm + +has_tensorrt_codegen = pytest.mark.skipif( + not tvm.get_global_func("relay.ext.tensorrt", True), reason="TensorRT codegen not available" +) +has_tensorrt_runtime = pytest.mark.skipif( + not tensorrt.is_tensorrt_runtime_enabled(), reason="TensorRT runtime not available" +) + +# conv2d+relu network +def get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, +): + + data = relay.var("data", relay.TensorType(data_shape, dtype)) + weight = relay.var("weight") + + net = relay.nn.conv2d( + data=data, + weight=weight, # conv kernel + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + channels=out_channels, + kernel_size=kernel_size, + data_layout=data_layout, + kernel_layout=kernel_layout, + ) + net = relay.add(net, net) + net = relay.nn.relu(net) + + inputs = relay.analysis.free_vars(net) + return relay.Function(inputs, net) + + +def verify_meta_schedule_with_tensorrt( + mod, + params, + data_shape, + use_meta_sched: bool = True, + use_trt: bool = True, + mode: str = "vm", +): + if use_meta_sched: + # With meta_schedule + dev = "nvidia/geforce-rtx-2080" + # Build + builder = LocalBuilder( + f_build=build_relay_with_tensorrt if use_trt else build_relay, + timeout_sec=1000, + ) + builder_input = BuilderInput(mod, Target(dev, host="llvm"), params) + builder_result = builder.build([builder_input])[0] + assert builder_result.error_msg is None, builder_result.error_msg + assert builder_result.artifact_path is not None + + # Run + runner_input = RunnerInput( + builder_result.artifact_path, + device_type="cuda", + args_info=[TensorInfo("float32", data_shape)], + ) + runner = LocalRunner( + evaluator_config=EvaluatorConfig( + number=5, + repeat=2, + min_repeat_ms=0, + enable_cpu_cache_flush=False, + ), + f_run_evaluator=run_with_graph_executor, + ) + + # Run the module + runner_future = runner.run([runner_input])[0] + runner_result = runner_future.result() + assert runner_result is not None + assert runner_result.error_msg is None, runner_result.error_msg + assert runner_result.run_secs is not None + + for result in runner_result.run_secs: + if isinstance(result, FloatImm): + result = result.value + assert isinstance(result, float) + assert result >= 0.0 + + else: + # Without meta_schedule + if use_trt: + mod, config = tensorrt.partition_for_tensorrt(mod) + with tvm.transform.PassContext( + opt_level=3, config={"relay.ext.tensorrt.options": config} + ): + _func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda" + ).evaluate() + else: + with tvm.transform.PassContext(opt_level=3): + _func = relay.create_executor( + mode, mod=mod, device=tvm.cuda(0), target="cuda", params=params + ).evaluate() + + +@has_tensorrt_codegen +def test_conv2d_relu(): + data_shape = (1, 1280, 14, 14) + out_channels = 256 + kernel_size, strides, padding, dilation, groups = (1, 1), (1, 1), (0, 0, 0, 0), (1, 1), 1 + data_layout, kernel_layout = "NCHW", "OIHW" + dtype = "float32" + + f = get_conv2d_relu( + data_shape, + out_channels, + kernel_size, + strides, + padding, + dilation, + groups, + data_layout, + kernel_layout, + dtype, + ) + + mod, params = testing.create_workload(f) + verify_meta_schedule_with_tensorrt(mod, params, data_shape) + + +@has_tensorrt_codegen +@pytest.mark.parametrize( + "model_name", + ["resnet-50", "mobilenet"], +) +@pytest.mark.parametrize("batch_size", [1, 8]) +@pytest.mark.parametrize("use_meta_sched", [True]) +@pytest.mark.parametrize("use_trt", [True, False]) +def test_relay_model(model_name: str, batch_size: int, use_meta_sched: bool, use_trt: bool): + mod, params, input_shape, _oshape = get_network( + name=model_name, + batch_size=batch_size, + ) + verify_meta_schedule_with_tensorrt( + mod, + params, + input_shape, + use_meta_sched=use_meta_sched, + use_trt=use_trt, + mode="vm", + ) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_meta_schedule_cost_model.py b/tests/python/unittest/test_meta_schedule_cost_model.py index 4cb018b29aa4..c939792b55ff 100644 --- a/tests/python/unittest/test_meta_schedule_cost_model.py +++ b/tests/python/unittest/test_meta_schedule_cost_model.py @@ -14,15 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-docstring +from typing import List + +import tempfile import os import re -import shutil import sys -import tempfile -from typing import List - -import numpy as np +import shutil import pytest import tvm @@ -34,6 +32,11 @@ from tvm.meta_schedule.tune_context import TuneContext from tvm.script import tir as T from tvm.tir.schedule.schedule import Schedule +from tvm.meta_schedule.search_strategy import MeasureCandidate +from tvm.meta_schedule.runner import RunnerResult +from tvm.meta_schedule.feature_extractor import RandomFeatureExtractor +from tvm.meta_schedule.cost_model import PyCostModel, RandomModel, XGBModel +from tvm.meta_schedule.tune_context import TuneContext # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @tvm.script.ir_module diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 95cf6ebaeb43..348339aee2e0 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -22,6 +22,7 @@ import pytest import tvm +from tvm.tir.schedule import BlockRV, Schedule from tvm.error import TVMError from tvm.meta_schedule import TuneContext from tvm.meta_schedule.schedule_rule import PyScheduleRule diff --git a/tests/python/unittest/test_meta_schedule_space_generator.py b/tests/python/unittest/test_meta_schedule_space_generator.py index 49a3f6309183..3eb050db3baa 100644 --- a/tests/python/unittest/test_meta_schedule_space_generator.py +++ b/tests/python/unittest/test_meta_schedule_space_generator.py @@ -23,6 +23,9 @@ import pytest import tvm +from tvm._ffi.base import TVMError +from tvm.ir.module import IRModule +from tvm.meta_schedule.space_generator.space_generator import PySpaceGenerator from tvm.script import tir as T from tvm.tir.schedule import Schedule from tvm.meta_schedule.space_generator import ScheduleFn, PySpaceGenerator, SpaceGeneratorUnion diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py b/tests/python/unittest/test_tir_schedule_compute_at.py index e1cf399d49a1..f0f5051f5c33 100644 --- a/tests/python/unittest/test_tir_schedule_compute_at.py +++ b/tests/python/unittest/test_tir_schedule_compute_at.py @@ -755,8 +755,9 @@ def read_out_of_bound_after_compute_at(a: T.handle, c: T.handle) -> None: T.where(j + i < 16) B[v] = A[v] with T.block("C"): - v = T.axis.S(16, j) - T.reads([B[v : v + 2]]) + v = T.axis.spatial(16, j) + T.reads(B[v : v + 2]) + T.writes(C[v]) C[v] = T.if_then_else(v < 15, T.max(B[v], B[v + 1]), B[v], dtype="float32") @@ -1253,5 +1254,28 @@ def test_fail_all_producers_under_loop(): sch.reverse_compute_at(block, loop) +def test_compute_at_tiled_pooling_cache(): + sch = tir.Schedule(tiled_pooling_cache, debug_mask="all") + compute = sch.get_block("compute") + _, w_o, _, _, _, _ = sch.get_loops(compute) + cache = sch.get_block("cache") + dache = sch.get_block("dache") + sch.compute_at(cache, w_o) + sch.compute_at(dache, w_o) + tvm.ir.assert_structural_equal(tiled_pooling_cache_after_compute_at, sch.mod["main"]) + verify_trace_roundtrip(sch=sch, mod=tiled_pooling_cache) + + +def test_reverse_compute_at_floordiv_and_floormod_indices(): + sch = tir.Schedule(floordiv_and_floormod_indices, debug_mask="all") + A = sch.get_block("A") + B = sch.get_block("B") + sch.reverse_compute_at(B, sch.get_loops(A)[0]) + tvm.ir.assert_structural_equal( + floordiv_and_floormod_indices_after_reverse_compute_at, sch.mod["main"] + ) + verify_trace_roundtrip(sch=sch, mod=floordiv_and_floormod_indices) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index cc2b114824a5..cf9621dc1d4c 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -25,7 +25,7 @@ from tvm.tir.schedule.testing import verify_trace_roundtrip -# pylint: disable=no-member,invalid-name,unused-variable +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long @T.prim_func