diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index aaa0e9c9174d..cb039cf07d5f 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -18,3 +18,5 @@ """Namespace for Ansor autoSchedule""" from .compute_dag import ComputeDAG +from .task import SearchTask +from .measure import MeasureInput, LocalBuilder, LocalRunner diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py index 3c46440f75ba..a66a181f054c 100644 --- a/python/tvm/ansor/compute_dag.py +++ b/python/tvm/ansor/compute_dag.py @@ -25,10 +25,56 @@ from . import _ffi_api +class LayoutRewriteLevel(object): + NO_REWRITE = 0 # No layout rewrite + PLACEHOLDER_REWRITE = 1 # Only rewrite layout of placeholder in the compute dag + COMPUTE_REWRITE = 2 # Only rewrite compute body for new layout in the compute dag + BOTH_REWRITE = 3 # Rewrite both placeholder and compute body in the compute dag + + @tvm._ffi.register_object("ansor.ComputeDAG") class ComputeDAG(Object): + """ + Parameters + ---------- + tensors : List[Tensor] + """ + def __init__(self, tensors): self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, tensors) - def get_init_state(self) -> State: - return self.init_state + def get_init_state(self): + """ Get init state of this ComputeDAG + + Returns + ------- + state : State + """ + return _ffi_api.ComputeDAGGetInitState(self) + + def apply_steps_from_state(self, state, layout_rewrite_level): + """ + Parameters + ---------- + state : State + layout_rewrite_level : LayoutRewriteLevel(***) + + Returns + ------- + sch : Schedule + args : List[Tensor] + """ + sch, args = _ffi_api.ComputeDAGApplyStepsFromState(self, state) + return sch, args + + def print_python_code_from_state(self, state): + """ + Parameters + ---------- + state : State + + Returns + ------- + str : Str + """ + return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state) diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py new file mode 100644 index 000000000000..72dd3cbfcf92 --- /dev/null +++ b/python/tvm/ansor/measure.py @@ -0,0 +1,434 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Distributed measurement infrastructure to measure the runtime costs of tensor programs + +These functions are responsible for building the tvm module, uploading it to +remote devices, recording the running time costs, and checking the correctness of the output. 
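For orientation, here is a rough usage sketch of how the pieces introduced in this patch are meant to compose; it is not part of the patch itself. The te matmul definition, the workload key string, the stage index passed to split, and the HardwareParams numbers are illustrative assumptions, while the ansor calls mirror the APIs added in compute_dag.py, task.py, state.py, and this file.

import tvm
from tvm import te, ansor

# Illustrative workload: a small matmul written with te (assumed, not in this patch).
A = te.placeholder((128, 128), name="A")
B = te.placeholder((128, 128), name="B")
k = te.reduce_axis((0, 128), name="k")
C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

dag = ansor.ComputeDAG([A, B, C])
state = dag.get_init_state()
# Transform steps now return both the updated state and the resulting iterators;
# stage 2 is assumed to be the compute stage of C in the initial state.
state, res_its = state.split(2, state.stage(2).iterator(0), [8])
print(dag.print_python_code_from_state(state))

# Example hardware parameters (the numbers are placeholders).
hw_params = ansor.task.HardwareParams(4, 64, 64, 16, 64)
task = ansor.SearchTask(dag, "example_matmul", tvm.target.create("llvm"),
                        hardware_params=hw_params)

inputs = [ansor.MeasureInput(task, state)]
builder = ansor.LocalBuilder(timeout=15)
runner = ansor.LocalRunner(timeout=10, number=3, repeat=1)
build_results = builder.build(inputs, verbose=1)                 # List[BuildResult]
measure_results = runner.run(inputs, build_results, verbose=1)   # List[MeasureResult]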
+ +We implement these in python to utilize python's multiprocessing and error handling +""" +from typing import List +import os +import time +import shutil +import logging +import traceback +import tempfile +import multiprocessing + +import tvm._ffi +from tvm.runtime import Object, module, ndarray +from tvm.driver import build_module +from tvm.target import build_config +from ..contrib import tar, ndk +from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout, request_remote, check_remote +from .compute_dag import LayoutRewriteLevel + +from . import _ffi_api + +logger = logging.getLogger('ansor') + + +@tvm._ffi.register_object("ansor.MeasureInput") +class MeasureInput(Object): + """ + Parameters + ---------- + task : SearchTask + state : State + """ + + def __init__(self, task, state): + self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state) + + +@tvm._ffi.register_object("ansor.BuildResult") +class BuildResult(Object): + """ + Parameters + ---------- + filename : Str + args : List[Tensor] + error_no : Int + error_msg : Str + time_cost : Float + """ + + def __init__(self, filename, args, error_no, error_msg, time_cost): + self.__init_handle_by_constructor__( + _ffi_api.BuildResult, filename, args, error_no, + error_msg if error_msg else "", time_cost) + + +@tvm._ffi.register_object("ansor.MeasureResult") +class MeasureResult(Object): + """ + Parameters + ---------- + costs : List[Float] + error_no : Int + error_msg : Str + all_cost : Float + timestamp : Float + """ + + def __init__(self, costs, error_no, error_msg, all_cost, timestamp): + self.__init_handle_by_constructor__( + _ffi_api.MeasureResult, costs, error_no, + error_msg if error_msg else "", all_cost, timestamp) + + +@tvm._ffi.register_object("ansor.Builder") +class Builder(Object): + def build(self, measure_inputs, verbose=0): + """ + Parameters + ---------- + measure_inputs : List[MeasureInput] + verbost : Int + + Returns + ------- + res : List[BuildResult] + """ + return _ffi_api.BuilderBuild(self, measure_inputs, verbose) + + +@tvm._ffi.register_object("ansor.Runner") +class Runner(Object): + def run(self, measure_inputs, build_results, verbose=0): + """ + Parameters + ---------- + measure_inputs : List[MeasureInput] + build_results : List[BuildResult] + + Returns + ------- + res : List[MeasureResult] + """ + return _ffi_api.RunnerRun(self, measure_inputs, build_results, verbose) + + +@tvm._ffi.register_object("ansor.LocalBuilder") +class LocalBuilder(Builder): + """ + Parameters + ---------- + timeout : Int + n_parallel : Int + build_func : Str + """ + + def __init__(self, + timeout=15, + n_parallel=multiprocessing.cpu_count(), + build_func='default'): + self.__init_handle_by_constructor__( + _ffi_api.LocalBuilder, timeout, n_parallel, build_func) + + +@tvm._ffi.register_object("ansor.LocalRunner") +class LocalRunner(Runner): + """ + Parameters + ---------- + timeout : Int + number : Int + repeat : Int + min_repeat_ms : Int + cooldown_interval : Float + """ + + def __init__(self, + timeout=10, + number=3, + repeat=1, + min_repeat_ms=0, + cooldown_interval=0.0): + self.__init_handle_by_constructor__( + _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) + + +MAX_ERROR_MSG_LEN = 512 + + +class MeasureErrorNo(object): + """Error type for MeasureResult""" + NO_ERROR = 0 # No error + INSTANTIATION_ERROR = 1 # Errors happen when apply transform steps from init state + # Errors happen when compiling code on host (e.g. 
tvm.build) + COMPILE_HOST = 2 + COMPILE_DEVICE = 3 # Errors happen when compiling code on device + # (e.g. OpenCL JIT on the device) + RUNTIME_DEVICE = 4 # Errors happen when run program on device + WRONG_ANSWER = 5 # Answer is wrong when compared to a reference output + BUILD_TIMEOUT = 6 # Timeout during compilation + RUN_TIMEOUT = 7 # Timeout during run + UNKNOWN_ERROR = 8 # Unknown error + + +def make_error_msg(): + error_msg = str(traceback.format_exc()) + if len(error_msg) > MAX_ERROR_MSG_LEN: + error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \ + "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN//2:] + return error_msg + + +global global_build_arguments +global global_run_arguments + + +def local_build_worker(index): + # We use fork to copy arguments from a global variable. + # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + measure_inputs, build_func, timeout, verbose = global_build_arguments + assert isinstance(build_func, str) + if build_func == 'default': + build_func = tar.tar + elif build_func == 'ndk': + build_func = ndk.create_shared + else: + raise ValueError("Invalid build_func" + build_func) + + def timed_func(): + tic = time.time() + inp = measure_inputs[index] + task = inp.task + + error_no = MeasureErrorNo.NO_ERROR + error_msg = None + args = [] + + try: + sch, args = task.compute_dag.apply_steps_from_state( + inp.state, LayoutRewriteLevel.BOTH_REWRITE) + except Exception: + error_no = MeasureErrorNo.INSTANTIATION_ERROR + error_msg = make_error_msg() + + if error_no == 0: + dirname = tempfile.mkdtemp() + filename = os.path.join( + dirname, "tmp_func." + build_func.output_format) + + try: + with build_config(unroll_max_extent=task.hardware_params.max_unroll_vec): + func = build_module.build( + sch, args, target=task.target, target_host=task.target_host) + func.export_library(filename, build_func) + except Exception: + error_no = MeasureErrorNo.COMPILE_HOST + error_msg = make_error_msg() + else: + filename = "" + + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print(".", end="") + else: + print(".E", end="") # Build error + return filename, args, error_no, error_msg, time.time() - tic + + res = call_func_with_timeout(timeout, timed_func) + if isinstance(res, TimeoutError): + if verbose >= 1: + print(".T", end="") # Build timeout + res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout + + return res + + +@tvm._ffi.register_func("ansor.local_builder.build") +def local_builder_build(inputs: List[MeasureInput], timeout: float, n_parallel: int, build_func: str, verbose: int): + # We use fork to copy arguments from a global variable. 
+ # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool + global global_build_arguments + global_build_arguments = (inputs, build_func, timeout, verbose) + + pool = NoDaemonPool(n_parallel) + tuple_res = pool.map(local_build_worker, range(len(inputs))) + pool.terminate() + pool.join() + del pool + + results = [] + for res in tuple_res: + results.append(BuildResult(*res)) + + return results + + +@tvm._ffi.register_func("ansor.rpc_runner.run") +def rpc_runner_run(inputs: List[MeasureInput], build_results: List[BuildResult], + key: str, host: str, port: int, priority: int, timeout: float, + n_parallel: int, number: int, repeat: int, min_repeat_ms: int, + cooldown_interval: float, verbose: int): + global global_run_arguments + global_run_arguments = (inputs, build_results, key, host, port, priority, timeout, number, + repeat, min_repeat_ms, cooldown_interval, verbose) + + assert len(inputs) == len(build_results), \ + "Measure input size should be equal to build results" + pool = NoDaemonPool(n_parallel) + tuple_res = pool.map(rpc_run_worker, range(len(build_results))) + pool.terminate() + pool.join() + del pool + + results = [] + for res in tuple_res: + results.append(MeasureResult(*res)) + + if verbose >= 1: + print("") + + return results + + +def rpc_run_worker(index): + inputs, build_results, key, host, port, priority, timeout, number, \ + repeat, min_repeat_ms, cooldown_interval, verbose = global_run_arguments + + MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log + inp = inputs[index] + build_res = build_results[index] + + if build_res.error_no != MeasureErrorNo.NO_ERROR: + return (MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, time.time() + + def timed_func(): + tic = time.time() + error_no = 0 + error_msg = None + try: + # upload built module + remote = request_remote(key, host, port, priority, timeout) + remote.upload(build_res.filename) + func = remote.load_module(os.path.split(build_res.filename)[1]) + ctx = remote.context(str(inp.task.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.COMPILE_DEVICE + error_msg = make_error_msg() + + if error_no == 0: + try: + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + build_res.args] + ctx.sync() + + costs = time_f(*args).results + # clean up remote files + remote.remove(build_res.filename) + remote.remove(os.path.splitext(build_res.filename)[0] + '.so') + remote.remove('') + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.RUNTIME_DEVICE + error_msg = make_error_msg() + + shutil.rmtree(os.path.dirname(build_res.filename)) + toc = time.time() + + time.sleep(cooldown_interval) + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print("*", end="") + else: + print("*E", end="") # Run error + + return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc + + res = call_func_with_timeout(timeout, timed_func) + + if isinstance(res, TimeoutError): + if verbose >= 1: + print("*T", end="") # Run timeout + res = (MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + \ + timeout, time.time() + return res + + +@tvm._ffi.register_func("ansor.local_runner.run") +def local_run(inputs: List[MeasureInput], build_results: List[BuildResult], + timeout: float, number: int, repeat: int, min_repeat_ms: int, + cooldown_interval: float, verbose: 
int): + MAX_FLOAT = 1e10 # We use 1e10 instead of sys.float_info.max for better readability in log + + def timed_func(inp, build_res): + tic = time.time() + error_no = 0 + error_msg = None + try: + func = module.load_module(build_res.filename) + ctx = ndarray.context(str(inp.task.target), 0) + time_f = func.time_evaluator( + func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms) + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.COMPILE_DEVICE + error_msg = make_error_msg() + + if error_no == 0: + try: + args = [ndarray.empty(get_const_tuple(x.shape), x.dtype, ctx) for x in + build_res.args] + ctx.sync() + + costs = time_f(*args).results + except Exception: + costs = (MAX_FLOAT,) + error_no = MeasureErrorNo.RUNTIME_DEVICE + error_msg = make_error_msg() + + shutil.rmtree(os.path.dirname(build_res.filename)) + toc = time.time() + time.sleep(cooldown_interval) + + if verbose >= 1: + if error_no == MeasureErrorNo.NO_ERROR: + print("*", end="") + else: + print("*E", end="") # Run error + return costs, error_no, error_msg, toc - tic + build_res.time_cost, toc + + measure_results = [] + assert len(inputs) == len(build_results), \ + "Measure input size should be equal to build results" + for inp, build_res in zip(inputs, build_results): + if build_res.error_no != 0: + res = ( + MAX_FLOAT,), build_res.error_no, build_res.error_msg, build_res.time_cost, time.time() + else: + res = call_func_with_timeout( + timeout, timed_func, args=(inp, build_res)) + if isinstance(res, TimeoutError): + if verbose >= 1: + print("*T", end="") # Run timeout + res = ( + MAX_FLOAT,), MeasureErrorNo.RUN_TIMEOUT, None, build_res.time_cost + timeout, time.time() + measure_results.append(MeasureResult(*res)) + + if verbose >= 1: + print("") + + return measure_results diff --git a/python/tvm/ansor/state.py b/python/tvm/ansor/state.py index 9a8810190199..7de95a8a74af 100644 --- a/python/tvm/ansor/state.py +++ b/python/tvm/ansor/state.py @@ -25,21 +25,41 @@ @tvm._ffi.register_object("ansor.Iterator") class Iterator(Object): + """ ... + """ pass @tvm._ffi.register_object("ansor.Stage") class Stage(Object): + """ ... + """ def iterator(self, index): + """ + Parameters + ---------- + index : Int + + Returns + ------- + iter : Iterator + """ return _ffi_api.StageGetIterator(self, index) def iterators(self): + """ + Returns + ------- + iters : List[Iterator] + """ return _ffi_api.StageGetIterators(self) @tvm._ffi.register_object("ansor.State") class State(Object): + """ ... 
+ """ def stage(self, index): """ @@ -93,10 +113,12 @@ def split(self, stage_id, it, lengths, inner_to_outer=True): ------- state : State The updated state + res_its : List[Iterator] + The splited Iterators result """ - state = _ffi_api.StateSplit(self, stage_id, it, lengths, - inner_to_outer) - return state + state, res_its = _ffi_api.StateSplit(self, stage_id, it, lengths, + inner_to_outer) + return state, res_its def follow_split(self, stage_id, it, src_step_id, n_split): """ @@ -115,10 +137,12 @@ def follow_split(self, stage_id, it, src_step_id, n_split): ------- state : State The updated state + res_its : List[Iterator] + The splited Iterators result """ - state = _ffi_api.StateFollowSplit(self, stage_id, it, src_step_id, - n_split) - return state + state, res_its = _ffi_api.StateFollowSplit(self, stage_id, it, + src_step_id, n_split) + return state, res_its def follow_fused_split(self, stage_id, it, src_step_ids, level, factor_or_nparts): @@ -140,10 +164,13 @@ def follow_fused_split(self, stage_id, it, src_step_ids, level, ------- state : State The updated state + res_its : List[Iterator] + The splited Iterators result """ - state = _ffi_api.StateFollowFusedSplit(self, stage_id, it, src_step_ids, - level, factor_or_nparts) - return state + state, res_its = _ffi_api.StateFollowFusedSplit(self, stage_id, it, + src_step_ids, level, + factor_or_nparts) + return state, res_its def fuse(self, stage_id, iters): """ @@ -158,9 +185,11 @@ def fuse(self, stage_id, iters): ------- state : State The updated state + res_it : Iterator + The fused Iterator """ - state = _ffi_api.StateFuse(self, stage_id, iters) - return state + state, res_it = _ffi_api.StateFuse(self, stage_id, iters) + return state, res_it def vectorize(self, stage_id, it): """ @@ -175,9 +204,11 @@ def vectorize(self, stage_id, it): ------- state : State The updated state + res_it : Iterator + The vectorized Iterator """ - state = _ffi_api.StateVectorize(self, stage_id, it) - return state + state, res_it = _ffi_api.StateVectorize(self, stage_id, it) + return state, res_it def parallel(self, stage_id, it): """ @@ -192,9 +223,11 @@ def parallel(self, stage_id, it): ------- state : State The updated state + res_it : Iterator + The paralleled Iterator """ - state = _ffi_api.StateParallel(self, stage_id, it) - return state + state, res_it = _ffi_api.StateParallel(self, stage_id, it) + return state, res_it def unroll(self, stage_id, it, max_unroll=-1): """ @@ -210,9 +243,11 @@ def unroll(self, stage_id, it, max_unroll=-1): ------- state : State The updated state + res_it : Iterator + The unrolled Iterator """ - state = _ffi_api.StateUnroll(self, stage_id, it, max_unroll) - return state + state, res_it = _ffi_api.StateUnroll(self, stage_id, it, max_unroll) + return state, res_it def bind_thread(self, stage_id, it, thread_type): """ @@ -229,9 +264,12 @@ def bind_thread(self, stage_id, it, thread_type): ------- state : State The updated state + res_it : Iterator + The thread binded Iterator """ - state = _ffi_api.StateBindThread(self, stage_id, it, thread_type) - return state + state, res_it = _ffi_api.StateBindThread(self, stage_id, it, + thread_type) + return state, res_it def compute_at(self, stage_id, target_stage_id, target_iter): """ @@ -311,10 +349,12 @@ def cache_read(self, stage_id, scope_name, reader_stage_ids, task_dag): ------- state : State The updated state + new_stage_id : Int + The added staged id """ - state = _ffi_api.StateCacheRead(self, stage_id, scope_name, - reader_stage_ids, task_dag) - return state + state, 
new_stage_id = _ffi_api.StateCacheRead(self, stage_id, + scope_name, reader_stage_ids, task_dag) + return state, int(new_stage_id) def cache_write(self, stage_id, scope_name, task_dag): """ @@ -329,9 +369,12 @@ def cache_write(self, stage_id, scope_name, task_dag): ------- state : State The updated state + new_stage_id : Int + The added staged id """ - state = _ffi_api.StateCacheWrite(self, stage_id, scope_name, task_dag) - return state + state, new_stage_id = _ffi_api.StateCacheWrite(self, stage_id, + scope_name, task_dag) + return state, int(new_stage_id) def pragma(self, stage_id, it, pragma_type): """ diff --git a/python/tvm/ansor/task.py b/python/tvm/ansor/task.py new file mode 100644 index 000000000000..245cf4c727ae --- /dev/null +++ b/python/tvm/ansor/task.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +""" ... """ + +import tvm._ffi +from tvm.runtime import Object + +from . import _ffi_api + +@tvm._ffi.register_object("ansor.HardwareParams") +class HardwareParams(Object): + """ + Parameters + ---------- + num_cores : Int + vector_unit_bytes : Int + cache_line_bytes : Int + max_unroll_vec : Int + max_innermost_split_factor : Int + """ + + def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes, + max_unroll_vec, max_innermost_split_factor): + self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores, + vector_unit_bytes, cache_line_bytes, max_unroll_vec, + max_innermost_split_factor) + + +@tvm._ffi.register_object("ansor.SearchTask") +class SearchTask(Object): + """ + Parameters + ---------- + dag : ComputeDAG + workload_key : Str + target : tvm.target + target_host : tvm.target + hardware_params : HardwareParams + """ + + def __init__(self, dag, workload_key, target, target_host=None, + hardware_params=None): + self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag, + workload_key, target, target_host, hardware_params) diff --git a/python/tvm/ansor/utils.py b/python/tvm/ansor/utils.py new file mode 100644 index 000000000000..0216549c184a --- /dev/null +++ b/python/tvm/ansor/utils.py @@ -0,0 +1,229 @@ +"""Common utilities""" +import multiprocessing +import multiprocessing.pool +import queue +import signal +import threading +import os + +import numpy as np + +try: + import psutil +except ImportError: + psutil = None + +from .. 
import rpc as _rpc +from tvm.tir import expr +from tvm.tir.transform import Simplify +from tvm.ir.transform import Sequential + + +def get_func_name(func): + """Get name of a function + + Parameters + ---------- + func: Function + The function + Returns + ------- + name: str + The name + """ + + return func.func_name if hasattr(func, 'func_name') else func.__name__ + + +def get_const_int(exp): + """Verifies expr is integer and get the constant value. + + Parameters + ---------- + exp : tvm.Expr or int + The input expression. + + Returns + ------- + out_value : int + The output. + """ + if isinstance(exp, int): + return exp + if not isinstance(exp, (expr.IntImm)): + opt = Sequential([Simplify()]) + exp = opt(exp) + if not isinstance(exp, (expr.IntImm)): + raise ValueError("Expect value to be constant int") + return exp.value + + +def get_const_tuple(in_tuple): + """Verifies input tuple is IntImm, returns tuple of int. + + Parameters + ---------- + in_tuple : tuple of Expr + The input. + + Returns + ------- + out_tuple : tuple of int + The output. + """ + return tuple(get_const_int(x) for x in in_tuple) + + +def to_str_round(x, decimal=6): + """Convert object to str and round float numbers""" + if isinstance(x, str): + return x + if isinstance(x, (list, tuple)) or isinstance(x, np.ndarray): + return "[" + ", ".join([to_str_round(y, decimal=decimal) + for y in x]) + "]" + if isinstance(x, dict): + return str({k: eval(to_str_round(v)) for k, v in x.items()}) + if isinstance(x, int): + return str(x) + if isinstance(x, (np.float32, np.float64, float)): + format_str = "%%.%df" % decimal + return format_str % x + raise ValueError("Invalid value: " + str(x) + "\ttype: " + str(type(x))) + + +def array_mean(arr): + """Mean function for tvm array (Array)""" + return sum(x.value for x in arr) / len(arr) + + +class NoDaemonProcess(multiprocessing.Process): + @property + def daemon(self): + return False + + @daemon.setter + def daemon(self, value): + pass + + +class NoDaemonContext(type(multiprocessing.get_context())): + Process = NoDaemonProcess + + +class NoDaemonPool(multiprocessing.pool.Pool): + """A no daemon pool version of multiprocessing.Pool. 
+ This allows us to start new processings inside the worker function""" + + def __init__(self, *args, **kwargs): + kwargs['context'] = NoDaemonContext() + super().__init__(*args, **kwargs) + + +def kill_child_processes(parent_pid, sig=signal.SIGTERM): + """kill all child processes recursively""" + try: + parent = psutil.Process(parent_pid) + except psutil.NoSuchProcess: + return + children = parent.children(recursive=True) + for process in children: + try: + process.send_signal(sig) + except psutil.NoSuchProcess: + return + + +def call_func_with_timeout(timeout, func, args=(), kwargs=None): + """Call a function with timeout""" + def func_wrapper(que): + if kwargs: + que.put(func(*args, **kwargs)) + else: + que.put(func(*args)) + + que = multiprocessing.Queue(2) + process = multiprocessing.Process(target=func_wrapper, args=(que,)) + process.start() + process.join(timeout) + + try: + res = que.get(block=False) + except queue.Empty: + res = TimeoutError() + + # clean queue and process + kill_child_processes(process.pid) + process.terminate() + process.join() + que.close() + que.join_thread() + del process + del que + + return res + + +def request_remote(device_key, host=None, port=None, priority=1, timeout=60): + """Request a remote session + + Parameters + ---------- + device_key: string + The device key of registered device in tracker + host: host, optional + The host address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_HOST" + port: int, optional + The port of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_PORT" + priority: int, optional + The priority of this request, larger is more prior + timeout: float, optional + The timeout of this session (units: second) + + Returns + ------ + session: RPCSession + """ + # connect to the tracker + host = host or os.environ['TVM_TRACKER_HOST'] + port = port or int(os.environ['TVM_TRACKER_PORT']) + + tracker = _rpc.connect_tracker(host, port) + remote = tracker.request(device_key, priority=priority, + session_timeout=timeout) + return remote + + +def check_remote(device_key, host=None, port=None, priority=100, timeout=10): + """ + Check the availability of a remote device + + Parameters + ---------- + device_key: string + device key of registered device in tracker + host: host, optional + The host address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_HOST" + port: int, optional + The port address of rpc tracker. + If is none, will use environment variable "TVM_TRACKER_PORT" + priority: int, optional + The priority of this request, larger is more prior + timeout: float, optional + The timeout of this check (units: seconds). 
+ + Returns + ------- + available: bool + True if can find available device + """ + + def _check(): + remote = request_remote(device_key, host, port, priority) + + t = threading.Thread(target=_check, ) + t.start() + t.join(timeout) + return not t.is_alive() diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 1e33068e4965..c9415a70c303 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -588,15 +588,6 @@ ComputeDAG ComputeDAGNode::make_by_workload_key(const std::string& workload_key) return ComputeDAGNode::make(std::move(tens)); } -void ComputeDAGNode::VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("tensors", &tensors); - v->Visit("ops", &ops); - v->Visit("flop_ct", &flop_ct); - v->Visit("access_analyzer", &access_analyzer); - State s = Downcast(init_state); - v->Visit("init_state", &s); -} - // Implemented in multi_stage_policy.cc // Extract primitive iterators from a nested fused or splitted iterator's name extern void ExtractOriginalIterators(const std::string& name, std::set* rets); @@ -1166,9 +1157,6 @@ std::pair > ComputeDAG::ReplaySteps( return std::make_pair(schedule, operator->()->tensors); } -TVM_REGISTER_GLOBAL("ansor.ComputeDAG") -.set_body_typed([](Array tensors) { return ComputeDAGNode::make(tensors); }); - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { auto* node = static_cast(ref.get()); @@ -1262,5 +1250,26 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } }); +TVM_REGISTER_GLOBAL("ansor.ComputeDAG") +.set_body_typed([](Array tensors) { + return ComputeDAGNode::make(tensors); +}); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGGetInitState") +.set_body_method(&ComputeDAG::GetInitState); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState") +.set_body_typed([](const ComputeDAG& dag, const State& state) { + te::Schedule sch; + Array return_tensors; + std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps); + return Array{sch, return_tensors}; +}); + +TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState") +.set_body_typed([](const ComputeDAG& dag, const State& state) { + return dag.PrintStepsAsPython(state->transform_steps); +}); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/compute_dag.h b/src/ansor/compute_dag.h index 9d0708a77f1c..3b4c80c50ad8 100644 --- a/src/ansor/compute_dag.h +++ b/src/ansor/compute_dag.h @@ -93,7 +93,13 @@ class ComputeDAGNode : public Object { AccessAnalyzer access_analyzer; // Read/Write accesss static analyzer ObjectRef init_state; // initial states - void VisitAttrs(tvm::AttrVisitor* v); + void VisitAttrs(tvm::AttrVisitor* v) { + LOG(INFO) << "ComputeDAG"; + v->Visit("tensors", &tensors); + v->Visit("ops", &ops); + v->Visit("flop_ct", &flop_ct); + v->Visit("access_analyzer", &access_analyzer); + } static ComputeDAG make(Array tensors); static ComputeDAG make_by_workload_key(const std::string& workload_key); diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index ebea5a1e472a..e18d36e34581 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -2,8 +2,10 @@ * Copyright (c) 2020 by Contributors */ #include "loop_state.h" -#include + #include +#include + #include "utils.h" namespace tvm { @@ -16,15 +18,15 @@ Stage StageNode::make(te::Operation op) { auto node = make_object(); if (op->IsInstance()) { node->op_type = kCompute; - auto *pop = op.as(); + auto* pop = op.as(); for (const auto& axis : pop->axis) { node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), - axis->dom, 
kSpace, kNone)); + axis->dom, kSpace, kNone)); } for (const auto& axis : pop->reduce_axis) { node->iters.push_back(IteratorNode::make(CleanName(axis->var->name_hint), - axis->dom, kReduce, kNone)); + axis->dom, kReduce, kNone)); } } else if (op->IsInstance()) { node->op_type = kPlaceholder; @@ -54,9 +56,8 @@ Stage StageNode::make(te::Operation op, StageType op_type, } Stage StageNode::make(te::Operation op, StageType op_type, - std::vector&& iters, - ComputeAtType compute_at, int16_t auto_unroll_max_step, - int storage_offset) { + std::vector&& iters, ComputeAtType compute_at, + int16_t auto_unroll_max_step, int storage_offset) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; @@ -67,16 +68,6 @@ Stage StageNode::make(te::Operation op, StageType op_type, return Stage(node); } -TVM_REGISTER_GLOBAL("ansor.StageGetIterator") - .set_body_typed([](const Stage& stage, int index) { - return stage->iters[index]; - }); - -TVM_REGISTER_GLOBAL("ansor.StageGetIterators") - .set_body_typed([](const Stage& stage) { - return Array(stage->iters); - }); - State StateNode::make_empty_state() { auto node = make_object(); node->attach_map = AttachMapNode::make(); @@ -97,8 +88,8 @@ State StateNode::make(const Array& ops) { } State StateNode::make(const std::vector& stages, - const std::vector& transform_steps, - bool complete, ObjectRef aux_info) { + const std::vector& transform_steps, bool complete, + ObjectRef aux_info) { auto node = make_object(); node->stages = stages; node->transform_steps = transform_steps; @@ -131,31 +122,32 @@ std::vector State::split(int stage_id, const Iterator& it, bool inner_to_outer) { const Stage& stage = operator->()->stages[stage_id]; - SplitStep step = SplitStepNode::make(stage_id, GetIndex(stage->iters, it), - it->range.defined() ? it->range->extent : PrimExpr(), lengths, - inner_to_outer); + SplitStep step = + SplitStepNode::make(stage_id, GetIndex(stage->iters, it), + it->range.defined() ? 
it->range->extent : PrimExpr(), + lengths, inner_to_outer); CopyOnWrite()->transform_steps.push_back(step); return DoSplitStep(step); } -std::vector State::follow_split(int stage_id, - const Iterator& it, int src_step_id, int n_split) { +std::vector State::follow_split(int stage_id, const Iterator& it, + int src_step_id, int n_split) { const Stage& stage = operator->()->stages[stage_id]; - FollowSplitStep step = FollowSplitStepNode::make(stage_id, - GetIndex(stage->iters, it), src_step_id, n_split); + FollowSplitStep step = FollowSplitStepNode::make( + stage_id, GetIndex(stage->iters, it), src_step_id, n_split); CopyOnWrite()->transform_steps.push_back(step); return DoFollowSplitStep(step); } - std::vector State::follow_fused_split( int stage_id, const Iterator& it, const std::vector& src_step_ids, int level, bool factor_or_nparts) { const Stage& stage = operator->()->stages[stage_id]; - FollowFusedSplitStep step = FollowFusedSplitStepNode::make(stage_id, - GetIndex(stage->iters, it), src_step_ids, level, factor_or_nparts); + FollowFusedSplitStep step = + FollowFusedSplitStepNode::make(stage_id, GetIndex(stage->iters, it), + src_step_ids, level, factor_or_nparts); CopyOnWrite()->transform_steps.push_back(step); return DoFollowFusedSplitStep(step); } @@ -179,16 +171,16 @@ Iterator State::vectorize(int stage_id, const Iterator& it) { Iterator State::parallel(int stage_id, const Iterator& it) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make( - stage_id, GetIndex(stage->iters, it), kParallel); + AnnotationStep step = + AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), kParallel); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { const Stage& stage = operator->()->stages[stage_id]; - AnnotationStep step = AnnotationStepNode::make(stage_id, - GetIndex(stage->iters, it), kUnroll); + AnnotationStep step = + AnnotationStepNode::make(stage_id, GetIndex(stage->iters, it), kUnroll); // don't unroll if the extent is larger than max_unroll if (max_unroll != -1 && it->range.defined()) { @@ -206,8 +198,8 @@ Iterator State::unroll(int stage_id, const Iterator& it, int max_unroll) { void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) { const Stage& target_stage = operator->()->stages[target_stage_id]; - ComputeAtStep step = ComputeAtStepNode::make(stage_id, target_stage_id, - GetIndex(target_stage->iters, target_iter)); + ComputeAtStep step = ComputeAtStepNode::make( + stage_id, target_stage_id, GetIndex(target_stage->iters, target_iter)); CopyOnWrite()->transform_steps.push_back(step); return DoComputeAtStep(step); } @@ -227,8 +219,8 @@ void State::compute_inline(int stage_id) { void State::pack_for_vec(int stage_id, const Iterator& target_iter, int vec_size) { const Stage& stage = operator->()->stages[stage_id]; - PackForVecStep step = PackForVecStepNode::make(stage_id, - GetIndex(stage->iters, target_iter), vec_size); + PackForVecStep step = PackForVecStepNode::make( + stage_id, GetIndex(stage->iters, target_iter), vec_size); CopyOnWrite()->transform_steps.push_back(step); return DoPackForVecStep(step); } @@ -240,8 +232,8 @@ Iterator State::bind_thread(int stage_id, const Iterator& it, LOG(FATAL) << "thread_type error, valide: kVThread, kBlockX, kThreadX, " << "kThreadY"; } - AnnotationStep step = AnnotationStepNode::make(stage_id, - GetIndex(stage->iters, it), thread_type); + AnnotationStep step = 
AnnotationStepNode::make( + stage_id, GetIndex(stage->iters, it), thread_type); CopyOnWrite()->transform_steps.push_back(step); return DoAnnotationStep(step); } @@ -249,14 +241,14 @@ Iterator State::bind_thread(int stage_id, const Iterator& it, int State::cache_read(int stage_id, const std::string& scope_name, const std::vector& reader_stage_ids, const ComputeDAG& task_dag) { - CacheReadStep step = CacheReadStepNode::make(stage_id, scope_name, - reader_stage_ids); + CacheReadStep step = + CacheReadStepNode::make(stage_id, scope_name, reader_stage_ids); CopyOnWrite()->transform_steps.push_back(step); return DoCacheReadStep(step, task_dag); } int State::cache_write(int stage_id, const std::string& scope_name, - const ComputeDAG& task_dag) { + const ComputeDAG& task_dag) { CacheWriteStep step = CacheWriteStepNode::make(stage_id, scope_name); CopyOnWrite()->transform_steps.push_back(step); return DoCacheWriteStep(step, task_dag); @@ -265,14 +257,14 @@ int State::cache_write(int stage_id, const std::string& scope_name, void State::pragma(int stage_id, const Iterator& it, const std::string& pragma_type) { const Stage& stage = operator->()->stages[stage_id]; - PragmaStep step = PragmaStepNode::make(stage_id, GetIndex(stage->iters, it), - pragma_type); + PragmaStep step = + PragmaStepNode::make(stage_id, GetIndex(stage->iters, it), pragma_type); CopyOnWrite()->transform_steps.push_back(step); return DoPragmaStep(step); } int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, - const ComputeDAG& task_dag) { + const ComputeDAG& task_dag) { const Stage& stage = operator->()->stages[stage_id]; RfactorStep step = RfactorStepNode::make(stage_id, GetIndex(stage->iters, it), factor_iter_id); @@ -283,8 +275,8 @@ int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) { const Stage& stage = operator->()->stages[stage_id]; - StorageAlignStep step = StorageAlignStepNode::make(stage_id, - GetIndex(stage->iters, it), factor, offset); + StorageAlignStep step = StorageAlignStepNode::make( + stage_id, GetIndex(stage->iters, it), factor, offset); CopyOnWrite()->transform_steps.push_back(step); return DoStorageAlignStep(step); } @@ -299,11 +291,9 @@ void State::DoReorderStep(const ReorderStep& step) { } StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(iters), - stage->compute_at, - stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[step->stage_id] = StageNode::make( + stage->op, stage->op_type, std::move(iters), stage->compute_at, + stage->auto_unroll_max_step, stage->storage_offset); } // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep @@ -324,7 +314,8 @@ std::vector State::DoSplitStepCommon( std::vector outs; for (size_t i = 0; i < lengths.size(); ++i) { - PrimExpr l; std::string name; + PrimExpr l; + std::string name; if (inner_to_outer) { l = lengths[lengths.size() - i - 1]; name = it->name + "." + std::to_string(lengths.size() - i); @@ -350,26 +341,26 @@ std::vector State::DoSplitStepCommon( range = Range::make_by_min_extent(tosplit_min, tosplit_extent); } if (inner_to_outer) { - outs.push_back(IteratorNode::make(it->name + ".0", range, it->iter_type, - kNone)); + outs.push_back( + IteratorNode::make(it->name + ".0", range, it->iter_type, kNone)); std::reverse(outs.begin(), outs.end()); } else { - outs.push_back(IteratorNode::make( - it->name + "." 
+ std::to_string(lengths.size()), range, it->iter_type, - kNone)); + outs.push_back( + IteratorNode::make(it->name + "." + std::to_string(lengths.size()), + range, it->iter_type, kNone)); } std::vector new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id); new_iters.insert(new_iters.end(), outs.begin(), outs.end()); - new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id+1, + new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[stage_id] = StageNode::make( + stage->op, stage->op_type, std::move(new_iters), stage->compute_at, + stage->auto_unroll_max_step, stage->storage_offset); // we have to replace the iterators in attach map, // these two vectors keep the replacement mapping @@ -396,8 +387,8 @@ std::vector State::DoFollowSplitStep(const FollowSplitStep& step) { std::vector State::DoFollowFusedSplitStep( const FollowFusedSplitStep& step) { - const PrimExpr& length = step->ExtractSplitLength( - operator->()->transform_steps); + const PrimExpr& length = + step->ExtractSplitLength(operator->()->transform_steps); return DoSplitStepCommon(step->stage_id, step->iter_id, {length}, step->factor_or_nparts); } @@ -414,15 +405,14 @@ Iterator State::DoFuseStep(const FuseStep& step) { std::vector ori_iters; for (size_t i = 0; i < step->fused_ids.size(); ++i) { if (i > 0) { - CHECK_EQ(step->fused_ids[i], step->fused_ids[i-1] + 1); + CHECK_EQ(step->fused_ids[i], step->fused_ids[i - 1] + 1); } if (i != step->fused_ids.size() - 1) { const auto& iter_to_attached_stage = - operator->()->attach_map->iter_to_attached_stages; - if (iter_to_attached_stage.find(std::make_pair(stage_id, - step->fused_ids[i])) - != iter_to_attached_stage.end()) { + operator->()->attach_map->iter_to_attached_stages; + if (iter_to_attached_stage.find(std::make_pair( + stage_id, step->fused_ids[i])) != iter_to_attached_stage.end()) { LOG(FATAL) << "Invalid Fuse. 
Because you want to fuse iterators " "that have been attached by some stages"; } @@ -451,8 +441,8 @@ Iterator State::DoFuseStep(const FuseStep& step) { if (new_extent.defined()) { range = Range::make_by_min_extent(0, new_extent); } - Iterator new_it = IteratorNode::make(new_name, range, new_iter_type, kNone, - &ori_iters); + Iterator new_it = + IteratorNode::make(new_name, range, new_iter_type, kNone, &ori_iters); std::vector new_iters; new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + step->fused_ids.front()); @@ -462,9 +452,9 @@ Iterator State::DoFuseStep(const FuseStep& step) { stage->iters.end()); StateNode* pstate = CopyOnWrite(); - pstate->stages[stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), stage->compute_at, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[stage_id] = StageNode::make( + stage->op, stage->op_type, std::move(new_iters), stage->compute_at, + stage->auto_unroll_max_step, stage->storage_offset); // we have to replace the iterators in attach map, // these two vectors keep the replacement mapping @@ -477,7 +467,7 @@ Iterator State::DoFuseStep(const FuseStep& step) { } else if (i > end_id) { // move forward from_iters.emplace_back(stage_id, i); to_iters.emplace_back(stage_id, i - end_id + begin_id); - } else { // move to the fused id + } else { // move to the fused id from_iters.emplace_back(stage_id, i); to_iters.emplace_back(stage_id, begin_id); } @@ -491,7 +481,7 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) { Iterator it = stage->iters[step->iter_id]; Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type, - step->annotation, &it->ori_iters); + step->annotation, &it->ori_iters); Stage new_stage = stage; new_stage.CopyOnWrite()->iters[step->iter_id] = new_it; StateNode* pstate = CopyOnWrite(); @@ -508,8 +498,8 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { std::vector new_iters; for (const Iterator& it : stage->iters) { size_t s = it->name.size(); - if (s >= 2 && it->name[s-2] == '.' && it->name[s-1] >= '1' && - it->name[s-1] <= '4') { + if (s >= 2 && it->name[s - 2] == '.' && it->name[s - 1] >= '1' && + it->name[s - 1] <= '4') { // We use a dangerous heuristic rule here : For multi level splitted // iterators, we assume their length does not change after compute_at. 
// Reason: These iterators are generated in MultiStagePolicy by multi @@ -519,14 +509,14 @@ void State::DoComputeAtStep(const ComputeAtStep& step) { new_iters.push_back(it); } else { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters)); + it->annotation, &it->ori_iters)); } } StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), kIter, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[step->stage_id] = + StageNode::make(stage->op, stage->op_type, std::move(new_iters), kIter, + stage->auto_unroll_max_step, stage->storage_offset); pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id); } @@ -540,14 +530,14 @@ void State::DoComputeRootStep(const ComputeRootStep& step) { std::vector new_iters; for (const Iterator& it : stage->iters) { new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type, - it->annotation, &it->ori_iters)); + it->annotation, &it->ori_iters)); } // update attach map StateNode* pstate = CopyOnWrite(); - pstate->stages[step->stage_id] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), kRoot, stage->auto_unroll_max_step, - stage->storage_offset); + pstate->stages[step->stage_id] = + StageNode::make(stage->op, stage->op_type, std::move(new_iters), kRoot, + stage->auto_unroll_max_step, stage->storage_offset); pstate->attach_map.DeleteStage(step->stage_id); } @@ -560,9 +550,10 @@ void State::DoComputeInlineStep(const ComputeInlineStep& step) { const auto& iter_to_attached_stages = pstate->attach_map->iter_to_attached_stages; for (size_t i = 0; i < stage->iters.size(); ++i) { - CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), 0) - << "Invalid compute_inline: Because there are some other stages " - "that are attached to the target stage"; + CHECK_EQ(iter_to_attached_stages.count(std::make_pair(step->stage_id, i)), + 0) + << "Invalid compute_inline: Because there are some other stages " + "that are attached to the target stage"; } pstate->stages[step->stage_id].CopyOnWrite()->compute_at = kInlined; @@ -576,7 +567,8 @@ void State::DoPackForVecStep(const PackForVecStep& step) { // Common part for steps that add new stages // (e.g. 
CacheReadStep, CacheWriteStep, RfactorStep) void AddStageModificationSteps(size_t step_id, - const std::vector& transform_steps, std::vector* replay_steps) { + const std::vector& transform_steps, + std::vector* replay_steps) { const Step& step = transform_steps[step_id]; if (step->IsInstance() || step->IsInstance()) { @@ -615,14 +607,15 @@ int State::DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag) { // target -> target + target_store // Should update target's op, insert new stage, update the later stage's op pstate->stages[step->stage_id].CopyOnWrite()->op = - operator->()->task_dag->ops[step->stage_id]; - pstate->stages.insert(pstate->stages.begin() + step->stage_id + 1, + operator->()->task_dag->ops[step->stage_id]; + pstate->stages.insert( + pstate->stages.begin() + step->stage_id + 1, StageNode::make(operator->()->task_dag->ops[step->stage_id + 1])); for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } - pstate->attach_map = - operator->()->attach_map.ApplyStageIdOfffset(step->stage_id + 1, 1); + pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( + step->stage_id + 1, 1); return step->stage_id + 1; } @@ -637,8 +630,9 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { } } - int last_dag_op_size = pstate->task_dag.defined() ? - pstate->task_dag->ops.size() : dag->ops.size(); + int last_dag_op_size = pstate->task_dag.defined() + ? pstate->task_dag->ops.size() + : dag->ops.size(); dag.ReplayAndGetDAG(replay_steps, &(pstate->task_dag)); int added_ops = pstate->task_dag->ops.size() - last_dag_op_size; CHECK_GE(added_ops, 1); @@ -646,7 +640,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { // target -> target_compute + target // Assume target stage has never been applied any steps before cache_write // Should insert new stage, update target stage, update the later stage's op - pstate->stages.insert(pstate->stages.begin() + step->stage_id, + pstate->stages.insert( + pstate->stages.begin() + step->stage_id, StageNode::make(operator->()->task_dag->ops[step->stage_id])); pstate->stages[step->stage_id + 1] = StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); @@ -657,7 +652,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { // for more information // TODO(jcf94): Fix this if (added_ops == 2) { - pstate->stages.insert(pstate->stages.begin() + next_stage_id, + pstate->stages.insert( + pstate->stages.begin() + next_stage_id, StageNode::make(operator->()->task_dag->ops[next_stage_id])); next_stage_id++; } else if (added_ops > 2) { @@ -666,8 +662,8 @@ int State::DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag) { for (size_t i = next_stage_id; i < operator->()->task_dag->ops.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } - pstate->attach_map = - operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, added_ops); + pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( + step->stage_id, added_ops); return step->stage_id; } @@ -702,18 +698,20 @@ int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) { // target -> target_compute + target // Should insert new stage, update target stage, update the later stage's op - pstate->stages.insert(pstate->stages.begin() + step->stage_id, + pstate->stages.insert( + pstate->stages.begin() + step->stage_id, 
StageNode::make(operator->()->task_dag->ops[step->stage_id])); // maintain the compute_at type of target stage - Stage target_stage = StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); + Stage target_stage = + StageNode::make(operator->()->task_dag->ops[step->stage_id + 1]); target_stage.CopyOnWrite()->compute_at = compute_at_type; pstate->stages[step->stage_id + 1] = target_stage; for (size_t i = step->stage_id + 2; i < operator->()->stages.size(); ++i) { pstate->stages[i].CopyOnWrite()->op = operator->()->task_dag->ops[i]; } - pstate->attach_map = - operator->()->attach_map.ApplyStageIdOfffset(step->stage_id, 1); + pstate->attach_map = operator->()->attach_map.ApplyStageIdOfffset( + step->stage_id, 1); return step->stage_id; } @@ -777,7 +775,6 @@ void State::DoSteps(const std::vector& steps, const ComputeDAG& dag) { } } - void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent, bool delete_trivial_loop) { const Stage& stage = state->stages[stage_id]; @@ -786,15 +783,15 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->func_name() << " auto_unroll: " - << stage->auto_unroll_max_step << "\n"; + *os << stage->op->func_name() + << " auto_unroll: " << stage->auto_unroll_max_step << "\n"; } if (stage->storage_offset != 0) { for (size_t j = 0; j < base_indent; ++j) { *os << " "; } - *os << stage->op->func_name() << " storage_offset: " - << stage->storage_offset << "\n"; + *os << stage->op->func_name() + << " storage_offset: " << stage->storage_offset << "\n"; } size_t indent = 0; @@ -802,26 +799,46 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state, const Iterator& iter = stage->iters[i]; if (!(delete_trivial_loop && iter->range.defined() && - is_one(iter->range->extent))) { + is_one(iter->range->extent))) { for (size_t j = 0; j < base_indent + indent; ++j) { *os << " "; } switch (iter->annotation) { - case kNone: *os << "for "; break; - case kUnroll: *os << "unroll "; break; - case kParallel: *os << "parallel "; break; - case kVectorize: *os << "vectorize "; break; - case kVThread: *os << "vthread "; break; - case kBlockX: *os << "gpu.blockIdx.x "; break; - case kBlockY: *os << "gpu.blockIdx.y "; break; - case kThreadX: *os << "gpu.threadIdx.x "; break; - case kThreadY: *os << "gpu.threadIdx.y "; break; + case kNone: + *os << "for "; + break; + case kUnroll: + *os << "unroll "; + break; + case kParallel: + *os << "parallel "; + break; + case kVectorize: + *os << "vectorize "; + break; + case kVThread: + *os << "vthread "; + break; + case kBlockX: + *os << "gpu.blockIdx.x "; + break; + case kBlockY: + *os << "gpu.blockIdx.y "; + break; + case kThreadX: + *os << "gpu.threadIdx.x "; + break; + case kThreadY: + *os << "gpu.threadIdx.y "; + break; } if (iter->range.defined()) { *os << iter->name << " (" << iter->range->min << "," - << iter->range->extent << ")" << "\n"; + << iter->range->extent << ")" + << "\n"; } else { - *os << iter->name << " (None)" << "\n"; + *os << iter->name << " (None)" + << "\n"; } indent += 2; @@ -885,6 +902,110 @@ std::string State::ToStr(bool delete_trivial_loop) const { return os.str(); } +void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, + int target_iter_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the current entry of stage + DeleteStageEntry(pnode, stage_id); + + // store the new relation + IterKey iter_key(target_stage_id, target_iter_id); + 
pnode->stage_to_attach_iter[stage_id] = + std::make_pair(target_stage_id, target_iter_id); + pnode->iter_to_attached_stages[iter_key].push_back(stage_id); +} + +void AttachMap::DeleteStage(int stage_id) { + AttachMapNode* pnode = CopyOnWrite(); + + // delete the entry of old stage + DeleteStageEntry(pnode, stage_id); +} + +void AttachMap::ReplaceIters(const std::vector& old_iters, + const std::vector& new_iters) { + AttachMapNode* pnode = CopyOnWrite(); + + CHECK_EQ(old_iters.size(), new_iters.size()); + for (size_t i = 0; i < old_iters.size(); ++i) { + auto entry = pnode->iter_to_attached_stages.find(old_iters[i]); + if (entry == pnode->iter_to_attached_stages.end()) { + continue; + } + + // replace iter in the value of `stage_to_attach_iter` + for (const auto& s : entry->second) { + pnode->stage_to_attach_iter[s] = new_iters[i]; + } + + // replace iter in the key of `iter_to_attached_stages` + std::vector attached_stages = std::move(entry->second); + pnode->iter_to_attached_stages.erase(entry); + pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); + } +} + +void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) { + auto old_entry = pnode->stage_to_attach_iter.find(stage_id); + if (old_entry != pnode->stage_to_attach_iter.end()) { + // delete value in `iter_to_attached_stages` + auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); + DeleteItem(&entry2->second, stage_id); + if (entry2->second.size() == 0) { + pnode->iter_to_attached_stages.erase(entry2); + } + // delete key in `stage_to_attach_iter` + pnode->stage_to_attach_iter.erase(old_entry); + } +} + +AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { + AttachMap map = AttachMapNode::make(); + auto pmap = map.CopyOnWrite(); + for (const auto& x : operator->()->stage_to_attach_iter) { + auto key = x.first; + if (key >= start_id) { + key += offset; + } + auto value = x.second; + if (value.first >= start_id) { + value.first += offset; + } + pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); + } + for (const auto& x : operator->()->iter_to_attached_stages) { + auto key = x.first; + if (key.first >= start_id) { + key.first += offset; + } + auto value = x.second; + for (auto& i : value) { + if (i >= start_id) { + i += offset; + } + } + pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); + } + return map; +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + PrintState(&p->stream, node, true); + }); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterator") + .set_body_typed([](const Stage& stage, int index) { + return stage->iters[index]; + }); + +TVM_REGISTER_GLOBAL("ansor.StageGetIterators") + .set_body_typed([](const Stage& stage) { + return Array(stage->iters); + }); + TVM_REGISTER_GLOBAL("ansor.StateGetStage") .set_body_typed([](const State& state, int index) { return state->stages[index]; @@ -908,21 +1029,20 @@ TVM_REGISTER_GLOBAL("ansor.StateReorder") TVM_REGISTER_GLOBAL("ansor.StateSplit") .set_body_typed([](State state, int stage_id, const Iterator& it, - const Array& lengths, - bool inner_to_outer) { + const Array& lengths, bool inner_to_outer) { std::vector len; for (const auto& i : lengths) { len.push_back(i); } - state.split(stage_id, it, len, inner_to_outer); - return state; + const auto& res = state.split(stage_id, it, len, inner_to_outer); + return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("ansor.StateFollowSplit") 
.set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id, int n_split) { - state.follow_split(stage_id, it, src_step_id, n_split); - return state; + const auto& res = state.follow_split(stage_id, it, src_step_id, n_split); + return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") @@ -933,9 +1053,9 @@ TVM_REGISTER_GLOBAL("ansor.StateFollowFusedSplit") for (const auto& i : src_step_ids) { array_src_step_ids.push_back(i->value); } - state.follow_fused_split(stage_id, it, array_src_step_ids, level, - factor_or_nparts); - return state; + const auto& res = state.follow_fused_split( + stage_id, it, array_src_step_ids, level, factor_or_nparts); + return Array{state, Array(res)}; }); TVM_REGISTER_GLOBAL("ansor.StateFuse") @@ -945,36 +1065,35 @@ TVM_REGISTER_GLOBAL("ansor.StateFuse") for (const auto& i : iters) { its.push_back(i); } - state.fuse(stage_id, its); - return state; + const auto& res = state.fuse(stage_id, its); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateVectorize") - .set_body_typed([](State state, int stage_id, - const Iterator& it) { - state.vectorize(stage_id, it); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.vectorize(stage_id, it); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateParallel") - .set_body_typed([](State state, int stage_id, - const Iterator& it) { - state.parallel(stage_id, it); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it) { + const auto& res = state.parallel(stage_id, it); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateUnroll") - .set_body_typed([](State state, int stage_id, - const Iterator& it, int max_unroll) { - state.unroll(stage_id, it, max_unroll); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it, + int max_unroll) { + const auto& res = state.unroll(stage_id, it, max_unroll); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateBindThread") - .set_body_typed([](State state, int stage_id, - const Iterator& it, int thread_type) { - state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); - return state; + .set_body_typed([](State state, int stage_id, const Iterator& it, + int thread_type) { + const auto& res = + state.bind_thread(stage_id, it, IteratorAnnotation(thread_type)); + return Array{state, res}; }); TVM_REGISTER_GLOBAL("ansor.StateComputeAt") @@ -997,8 +1116,8 @@ TVM_REGISTER_GLOBAL("ansor.StateComputeInline") }); TVM_REGISTER_GLOBAL("ansor.StatePackForVec") - .set_body_typed([](State state, int stage_id, - const Iterator& target_iter, int vec_size) { + .set_body_typed([](State state, int stage_id, const Iterator& target_iter, + int vec_size) { state.pack_for_vec(stage_id, target_iter, vec_size); return state; }); @@ -1011,110 +1130,17 @@ TVM_REGISTER_GLOBAL("ansor.StateCacheRead") for (const auto& i : reader_stage_ids) { array_reader_stage_ids.push_back(i->value); } - state.cache_read(stage_id, scope_name, array_reader_stage_ids, task_dag); - return state; + int res = state.cache_read(stage_id, scope_name, array_reader_stage_ids, + task_dag); + return Array{state, IntImm(DataType::Int(32), res)}; }); TVM_REGISTER_GLOBAL("ansor.StateCacheWrite") .set_body_typed([](State state, int stage_id, const std::string& scope_name, const ComputeDAG& task_dag) { - state.cache_write(stage_id, scope_name, task_dag); - return state; + int res = state.cache_write(stage_id, scope_name, task_dag); + return 
Array{state, IntImm(DataType::Int(32), res)}; }); -void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, - int target_iter_id) { - AttachMapNode* pnode = CopyOnWrite(); - - // delete the current entry of stage - DeleteStageEntry(pnode, stage_id); - - // store the new relation - IterKey iter_key(target_stage_id, target_iter_id); - pnode->stage_to_attach_iter[stage_id] = std::make_pair(target_stage_id, - target_iter_id); - pnode->iter_to_attached_stages[iter_key].push_back(stage_id); -} - -void AttachMap::DeleteStage(int stage_id) { - AttachMapNode* pnode = CopyOnWrite(); - - // delete the entry of old stage - DeleteStageEntry(pnode, stage_id); -} - -void AttachMap::ReplaceIters(const std::vector& old_iters, - const std::vector& new_iters) { - AttachMapNode* pnode = CopyOnWrite(); - - CHECK_EQ(old_iters.size(), new_iters.size()); - for (size_t i = 0; i < old_iters.size(); ++i) { - auto entry = pnode->iter_to_attached_stages.find(old_iters[i]); - if (entry == pnode->iter_to_attached_stages.end()) { - continue; - } - - // replace iter in the value of `stage_to_attach_iter` - for (const auto& s : entry->second) { - pnode->stage_to_attach_iter[s] = new_iters[i]; - } - - // replace iter in the key of `iter_to_attached_stages` - std::vector attached_stages = std::move(entry->second); - pnode->iter_to_attached_stages.erase(entry); - pnode->iter_to_attached_stages[new_iters[i]] = std::move(attached_stages); - } -} - -void AttachMap::DeleteStageEntry(AttachMapNode *pnode, int stage_id) { - auto old_entry = pnode->stage_to_attach_iter.find(stage_id); - if (old_entry != pnode->stage_to_attach_iter.end()) { - // delete value in `iter_to_attached_stages` - auto entry2 = pnode->iter_to_attached_stages.find(old_entry->second); - DeleteItem(&entry2->second, stage_id); - if (entry2->second.size() == 0) { - pnode->iter_to_attached_stages.erase(entry2); - } - // delete key in `stage_to_attach_iter` - pnode->stage_to_attach_iter.erase(old_entry); - } -} - -AttachMap AttachMap::ApplyStageIdOfffset(int start_id, int offset) const { - AttachMap map = AttachMapNode::make(); - auto pmap = map.CopyOnWrite(); - for (const auto& x : operator->()->stage_to_attach_iter) { - auto key = x.first; - if (key >= start_id) { - key += offset; - } - auto value = x.second; - if (value.first >= start_id) { - value.first += offset; - } - pmap->stage_to_attach_iter.insert(std::make_pair(key, value)); - } - for (const auto& x : operator->()->iter_to_attached_stages) { - auto key = x.first; - if (key.first >= start_id) { - key.first += offset; - } - auto value = x.second; - for (auto& i : value) { - if (i >= start_id) { - i += offset; - } - } - pmap->iter_to_attached_stages.insert(std::make_pair(key, value)); - } - return map; -} - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - PrintState(&p->stream, node, true); -}); - } // namespace ansor } // namespace tvm diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc index 1bae02b3f2c5..b2cff24973bc 100644 --- a/src/ansor/measure.cc +++ b/src/ansor/measure.cc @@ -3,12 +3,13 @@ */ #include "measure.h" // #include -#include #include +#include + +#include #include #include #include -#include // #include "search_policy/search_policy.h" namespace tvm { @@ -25,16 +26,16 @@ TVM_REGISTER_OBJECT_TYPE(RPCRunnerNode); TVM_REGISTER_OBJECT_TYPE(LocalRunnerNode); TVM_REGISTER_OBJECT_TYPE(ProgramMeasurerNode); -const char *ErrorNoToStr[] = { - "NoError", - "InstantiationError", - 
"CompileHostError", - "CompileDeviceError", - "RuntimeDeviceError", - "WrongAnswerError", - "BuildTimeoutError", - "RunTimeoutError", - "UnknownError", +const char* ErrorNoToStr[] = { + "NoError", + "InstantiationError", + "CompileHostError", + "CompileDeviceError", + "RuntimeDeviceError", + "WrongAnswerError", + "BuildTimeoutError", + "RunTimeoutError", + "UnknownError", }; // Maker @@ -52,8 +53,9 @@ MeasureInput MeasureInputNode::copy() const { return MeasureInput(node); } -BuildResult BuildResultNode::make(std::string filename, Array args, int error_no, - std::string error_msg, double time_cost) { +BuildResult BuildResultNode::make(std::string filename, Array args, + int error_no, std::string error_msg, + double time_cost) { auto node = make_object(); node->filename = std::move(filename); node->args = std::move(args); @@ -64,7 +66,8 @@ BuildResult BuildResultNode::make(std::string filename, Array args, } MeasureResult MeasureResultNode::make(Array costs, int error_no, - std::string error_msg, double all_cost, double timestamp) { + std::string error_msg, double all_cost, + double timestamp) { auto node = make_object(); node->costs = std::move(costs); node->error_no = error_no; @@ -84,7 +87,8 @@ MeasureResult MeasureResultNode::copy() const { return MeasureResult(node); } -Builder LocalBuilderNode::make(int timeout, int n_parallel, const std::string& build_func) { +Builder LocalBuilderNode::make(int timeout, int n_parallel, + const std::string& build_func) { auto node = make_object(); node->timeout = timeout; node->n_parallel = n_parallel; @@ -93,9 +97,11 @@ Builder LocalBuilderNode::make(int timeout, int n_parallel, const std::string& b } // LocalBuilder and LocalRunner -Array LocalBuilderNode::Build(const Array &inputs, int verbose) { +Array LocalBuilderNode::Build(const Array& inputs, + int verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_builder.build")) { - Array results = (*f)(inputs, timeout, n_parallel, build_func, verbose); + Array results = + (*f)(inputs, timeout, n_parallel, build_func, verbose); return results; } else { LOG(FATAL) << "ansor.local_builder.build is not registered"; @@ -103,9 +109,10 @@ Array LocalBuilderNode::Build(const Array &inputs, in return Array(); } -Runner RPCRunnerNode::make(const std::string& key, const std::string& host, int port, - int priority, int timeout, int n_parallel, int number, - int repeat, int min_repeat_ms, double cooldown_interval) { +Runner RPCRunnerNode::make(const std::string& key, const std::string& host, + int port, int priority, int timeout, int n_parallel, + int number, int repeat, int min_repeat_ms, + double cooldown_interval) { auto node = make_object(); node->key = key; node->host = host; @@ -124,9 +131,9 @@ Array RPCRunnerNode::Run(const Array& inputs, const Array& build_results, int verbose) { if (const auto* f = runtime::Registry::Get("ansor.rpc_runner.run")) { - Array results = (*f)(inputs, build_results, key, host, port, priority, - timeout, n_parallel, number, repeat, - min_repeat_ms, cooldown_interval, verbose); + Array results = (*f)( + inputs, build_results, key, host, port, priority, timeout, n_parallel, + number, repeat, min_repeat_ms, cooldown_interval, verbose); return results; } else { LOG(FATAL) << "ansor.rpc_runner.run is not registered"; @@ -145,12 +152,13 @@ Runner LocalRunnerNode::make(int timeout, int number, int repeat, return Runner(node); } -Array LocalRunnerNode::Run(const Array& inputs, - const Array& build_results, - int verbose) { +Array LocalRunnerNode::Run( + const Array& inputs, 
const Array& build_results, + int verbose) { if (const auto* f = runtime::Registry::Get("ansor.local_runner.run")) { - Array results = (*f)(inputs, build_results, timeout, number, - repeat, min_repeat_ms, cooldown_interval, verbose); + Array results = + (*f)(inputs, build_results, timeout, number, repeat, min_repeat_ms, + cooldown_interval, verbose); return results; } else { LOG(FATAL) << "ansor.local_runner.run is not registered"; @@ -167,8 +175,9 @@ ProgramMeasurer ProgramMeasurerNode::make(Builder builder, Runner runner, node->runner = std::move(runner); node->callbacks = std::move(callbacks); node->verbose = verbose; - node->max_continous_error = max_continous_error < 0 ? - DEFAULT_MAX_CONTINOUS_ERROR : max_continous_error; + node->max_continous_error = max_continous_error < 0 + ? DEFAULT_MAX_CONTINOUS_ERROR + : max_continous_error; return ProgramMeasurer(node); } @@ -192,12 +201,14 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, batch_size = builder->n_parallel * 2; } - StdCout(verbose) << "Get " << inputs.size() << " programs for measure. (This may take a while)" - << std::endl; + StdCout(verbose) << "Get " << inputs.size() + << " programs for measure. (This may take a while)" + << std::endl; for (size_t i = 0; i < inputs.size(); i += batch_size) { - std::vector input_batch(inputs.begin() + i, - inputs.begin() + std::min(i + batch_size, inputs.size())); + std::vector input_batch( + inputs.begin() + i, + inputs.begin() + std::min(i + batch_size, inputs.size())); std::vector result_batch; // build and run @@ -207,7 +218,8 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, for (size_t j = 0; j < input_batch.size(); ++j) { double flops; if (result_batch[j]->error_no == 0) { - flops = task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); + flops = + task->compute_dag->flop_ct / FloatArrayMean(result_batch[j]->costs); error_ct = 0; } else { flops = 0.0; @@ -225,8 +237,8 @@ void ProgramMeasurerNode::Measure(const SearchTask& task, if (verbose >= 1) { std::cout << std::fixed << std::setprecision(2); std::cout << "===============================================\n"; - std::cout << "No: " << ct - << "\tGFLOPS: " << flops / 1e9 << " / " << best_flops[workload_key] / 1e9 + std::cout << "No: " << ct << "\tGFLOPS: " << flops / 1e9 << " / " + << best_flops[workload_key] / 1e9 << "\tresults: " << result_batch[j] << "\n"; std::cout << "===============================================\n"; std::cout << input_batch[j]->state << "\n"; @@ -261,7 +273,8 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, // Call builder and runner Array build_res_batch = builder->Build(input_batch, verbose); - Array result_batch = runner->Run(input_batch, build_res_batch, verbose); + Array result_batch = + runner->Run(input_batch, build_res_batch, verbose); // Store result batch for (auto& res : result_batch) { @@ -271,44 +284,89 @@ void ProgramMeasurerNode::SilentMeasure(const SearchTask& task, // Printing functions TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - p->stream << "MeasureInput()"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + p->stream << "MeasureInput()"; + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - if (node->error_no == kNoError) { - p->stream << "MeasureResult(cost:["; - auto old_config = p->stream.precision(4); - for (size_t i = 0; i < node->costs.size(); ++i) { - auto pf = 
node->costs[i].as(); - CHECK(pf != nullptr); - p->stream << pf->value; - if (i != node->costs.size() - 1) { - p->stream << ","; + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + if (node->error_no == kNoError) { + p->stream << "MeasureResult(cost:["; + auto old_config = p->stream.precision(4); + for (size_t i = 0; i < node->costs.size(); ++i) { + auto pf = node->costs[i].as(); + CHECK(pf != nullptr); + p->stream << pf->value; + if (i != node->costs.size() - 1) { + p->stream << ","; + } + } + p->stream.precision(old_config); + p->stream << "], "; + p->stream << "error_no:" << 0 << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; + } else { + p->stream << "MeasureResult(" + << "error_type:" << ErrorNoToStr[node->error_no] << ", " + << "error_msg:" << node->error_msg << ", " + << "all_cost:" << node->all_cost << ", " + << "Tstamp:" << node->timestamp << ")"; } - } - p->stream.precision(old_config); - p->stream << "], "; - p->stream << "error_no:" << 0 << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; - } else { - p->stream << "MeasureResult(" - << "error_type:" << ErrorNoToStr[node->error_no] << ", " - << "error_msg:" << node->error_msg << ", " - << "all_cost:" << node->all_cost << ", " - << "Tstamp:" << node->timestamp << ")"; - } -}); + }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch([](const ObjectRef& ref, ReprPrinter *p) { - auto* node = static_cast(ref.get()); - p->stream << "BuildResult(" << node->filename << ", " << node->error_no - << ", " << node->time_cost << ")"; -}); + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "BuildResult(" << node->filename << ", " << node->error_no + << ", " << node->time_cost << ")"; + }); + +TVM_REGISTER_GLOBAL("ansor.MeasureInput") + .set_body_typed([](SearchTask task, State state) { + return MeasureInputNode::make(task, state); + }); + +TVM_REGISTER_GLOBAL("ansor.BuildResult") + .set_body_typed([](std::string filename, Array args, + int error_no, std::string error_msg, double time_cost) { + return BuildResultNode::make(filename, args, error_no, error_msg, + time_cost); + }); + +TVM_REGISTER_GLOBAL("ansor.MeasureResult") + .set_body_typed([](Array costs, int error_no, + std::string error_msg, double all_cost, + double timestamp) { + return MeasureResultNode::make(costs, error_no, error_msg, all_cost, + timestamp); + }); + +TVM_REGISTER_GLOBAL("ansor.BuilderBuild") + .set_body_typed([](const Builder& builder, + const Array& inputs, int verbose) { + return builder->Build(inputs, verbose); + }); + +TVM_REGISTER_GLOBAL("ansor.RunnerRun") + .set_body_typed([](const Runner& runner, const Array& inputs, + const Array& build_results, int verbose) { + return runner->Run(inputs, build_results, verbose); + }); + +TVM_REGISTER_GLOBAL("ansor.LocalBuilder") + .set_body_typed([](int timeout, int n_parallel, + const std::string& build_func) { + return LocalBuilderNode::make(timeout, n_parallel, build_func); + }); + +TVM_REGISTER_GLOBAL("ansor.LocalRunner") + .set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms, + double cooldown_interval) { + return LocalRunnerNode::make(timeout, number, repeat, min_repeat_ms, + cooldown_interval); + }); } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.cc b/src/ansor/search_task.cc index b9cda9168b9e..93f3f60ea768 100644 --- a/src/ansor/search_task.cc +++ 
b/src/ansor/search_task.cc @@ -2,20 +2,23 @@ * Copyright (c) 2020 by Contributors */ #include "search_task.h" -#include -#include + #include -#include +#include +#include + #include +#include namespace tvm { namespace ansor { -TVM_REGISTER_OBJECT_TYPE(HardwareParamsNode); -TVM_REGISTER_OBJECT_TYPE(SearchTaskNode); +TVM_REGISTER_NODE_TYPE(HardwareParamsNode); +TVM_REGISTER_NODE_TYPE(SearchTaskNode); HardwareParams HardwareParamsNode::make(int num_cores, int vector_unit_bytes, - int cache_line_bytes, int max_unroll_vec, + int cache_line_bytes, + int max_unroll_vec, int max_innermost_split_factor) { auto node = make_object(); node->num_cores = num_cores; @@ -40,21 +43,19 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( auto ctx = TVMContext{kDLGPU, 0}; auto func = tvm::runtime::Registry::Get("device_api.gpu"); CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; - auto device_api = static_cast(((*func)()).operator void*()); + auto device_api = + static_cast(((*func)()).operator void*()); tvm::runtime::TVMRetValue ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, - &ret); + device_api->GetAttr( + ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); p_hardware_params->max_shared_memory_per_block = ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, - &ret); + device_api->GetAttr( + ctx, tvm::runtime::DeviceAttrKind::kMaxRegistersPerBlock, &ret); p_hardware_params->max_registers_per_block = ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); p_hardware_params->max_threads_per_block = ret; @@ -73,16 +74,15 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( auto ctx = TVMContext{kDLOpenCL, 0}; auto func = tvm::runtime::Registry::Get("device_api.opencl"); CHECK(func != nullptr) << "Cannot find GPU device_api in registry"; - auto device_api = static_cast(((*func)()).operator void*()); + auto device_api = + static_cast(((*func)()).operator void*()); tvm::runtime::TVMRetValue ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, - &ret); + device_api->GetAttr( + ctx, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); p_hardware_params->max_shared_memory_per_block = ret; - device_api->GetAttr(ctx, - tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, + device_api->GetAttr(ctx, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); p_hardware_params->max_threads_per_block = ret; @@ -99,9 +99,10 @@ HardwareParams HardwareParamsNode::GetDefaultHardwareParams( return HardwareParams(); } - -SearchTask SearchTaskNode::make(ComputeDAG compute_dag, std::string workload_key, - Target target, Target target_host, HardwareParams hardware_params) { +SearchTask SearchTaskNode::make(ComputeDAG compute_dag, + std::string workload_key, Target target, + Target target_host, + HardwareParams hardware_params) { auto node = make_object(); node->compute_dag = std::move(compute_dag); node->workload_key = std::move(workload_key); @@ -116,5 +117,22 @@ SearchTask SearchTaskNode::make(ComputeDAG compute_dag, std::string workload_key return SearchTask(node); } +TVM_REGISTER_GLOBAL("ansor.HardwareParams") + .set_body_typed([](int num_cores, int vector_unit_bytes, + int cache_line_bytes, int max_unroll_vec, + int max_innermost_split_factor) { + return HardwareParamsNode::make(num_cores, vector_unit_bytes, + cache_line_bytes, 
max_unroll_vec, + max_innermost_split_factor); + }); + +TVM_REGISTER_GLOBAL("ansor.SearchTask") + .set_body_typed([](ComputeDAG compute_dag, std::string workload_key, + Target target, Target target_host, + HardwareParams hardware_params) { + return SearchTaskNode::make(compute_dag, workload_key, target, + target_host, hardware_params); + }); + } // namespace ansor } // namespace tvm diff --git a/src/ansor/search_task.h b/src/ansor/search_task.h index 7db98a5197a5..9512013848b6 100644 --- a/src/ansor/search_task.h +++ b/src/ansor/search_task.h @@ -8,13 +8,16 @@ #define TVM_ANSOR_SEARCH_TASK_H_ #include + #include + #include "compute_dag.h" namespace tvm { namespace ansor { -class HardwareParams; class SearchTask; +class HardwareParams; +class SearchTask; /*! \brief Hardware related parameters */ class HardwareParamsNode : public Object { @@ -54,12 +57,11 @@ class HardwareParamsNode : public Object { static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host); - static constexpr const char *_type_key = "ansor.HardwareParams"; + static constexpr const char* _type_key = "ansor.HardwareParams"; TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object); }; TVM_DEFINE_COW_NODE_REF(HardwareParams, ObjectRef, HardwareParamsNode); - /*! \brief Meta-info for a search task */ class SearchTaskNode : public Object { public: @@ -81,7 +83,7 @@ class SearchTaskNode : public Object { Target target, Target target_host, HardwareParams hardware_params); - static constexpr const char *_type_key = "ansor.SearchTask"; + static constexpr const char* _type_key = "ansor.SearchTask"; TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object); }; TVM_DEFINE_COW_NODE_REF(SearchTask, ObjectRef, SearchTaskNode); diff --git a/tests/cpp/ansor_test.cc b/tests/cpp/ansor_test.cc index 75a6cc00b802..e5a2c98c02a9 100644 --- a/tests/cpp/ansor_test.cc +++ b/tests/cpp/ansor_test.cc @@ -242,15 +242,15 @@ TEST(Step, SplitFuseReorder) { CHECK_EQ(s1->stages[2]->iters[0]->range->extent.as()->value, 512); its = s0.split(2, ti, {16}); + Iterator tio = its[0], tii = its[1]; CHECK_EQ(s0->stages[2]->iters[0]->range->extent.as()->value, 32); CHECK_EQ(s0->stages[2]->iters[1]->range->extent.as()->value, 16); - Iterator tio = its[0], tii = its[1]; its = s0.split(2, tj, {8}); + Iterator tjo = its[0], tji = its[1]; CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 64); CHECK_EQ(s0->stages[2]->iters[3]->range->extent.as()->value, 8); - Iterator tjo = its[0], tji = its[1]; s0.reorder(2, {tio, tjo, tk, tji, tii}); CHECK_EQ(s0->stages[2]->iters[2]->range->extent.as()->value, 512); diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py index 4782f9130cea..da87ea5fe9cf 100644 --- a/tests/python/unittest/test_ansor_common.py +++ b/tests/python/unittest/test_ansor_common.py @@ -73,26 +73,26 @@ def test_state_split_fuse_reorder(): assert ti.range.extent == 512 - s0 = s0.split(2, ti, [16]) + s0, its = s0.split(2, ti, [16]) + tio = its[0] + tii = its[1] assert s0.stage(2).iterator(0).range.extent == 32 assert s0.stage(2).iterator(1).range.extent == 16 - tio = s0.stage(2).iterator(0) - tii = s0.stage(2).iterator(1) - s0 = s0.split(2, tj, [8]) + s0, its = s0.split(2, tj, [8]) + tjo = its[0] + tji = its[1] assert s0.stage(2).iterator(2).range.extent == 64 assert s0.stage(2).iterator(3).range.extent == 8 - tjo = s0.stage(2).iterator(2) - tji = s0.stage(2).iterator(3) s0 = s0.reorder(2, [tio, tjo, tk, tji, tii]) assert s0.stage(2).iterator(2).range.extent == 512 - s0 = 
s0.fuse(2, [tio, tjo]) - assert s0.stage(2).iterator(0).range.extent == 2048 + s0, res_it = s0.fuse(2, [tio, tjo]) + assert res_it.range.extent == 2048 - s1 = s1.split(2, ti, [8, 2]) - s1 = s1.split(2, tj, [32, 8], False) + s1, _ = s1.split(2, ti, [8, 2]) + s1, _ = s1.split(2, tj, [32, 8], False) assert s1.stage(2).iterator(0).range.extent == 32 assert s1.stage(2).iterator(1).range.extent == 8 assert s1.stage(2).iterator(2).range.extent == 2 @@ -186,22 +186,19 @@ def test_state_cache_read_write(): # 0: init state s0 = dag.get_init_state() ori_its = s0.stage(add).iterators() - s0 = s0.split(add, s0.stage(add).iterator(0), [2]) - s0 = s0.reorder(add, [s0.stage(add).iterator(0), ori_its[1], - s0.stage(add).iterator(1), ori_its[2], ori_its[3]]) + s0, its = s0.split(add, s0.stage(add).iterator(0), [2]) + s0 = s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]]) s0 = s0.compute_inline(relu) # 1: simple cache_write with compute_at - s0 = s0.cache_write(conv, "global", dag) - conv_global = conv + s0, conv_global = s0.cache_write(conv, "global", dag) conv += 1 relu += 1 add += 1 s0 = s0.compute_at(conv_global, conv, s0.stage(conv).iterator(3)) # 2: simple cache_read with compute_at - s0 = s0.cache_read(kernel, "global", [conv_global], dag) - kernel_global = kernel + 1 + s0, kernel_global = s0.cache_read(kernel, "global", [conv_global], dag) conv_global += 1 conv += 1 relu += 1 @@ -252,8 +249,7 @@ def test_state_cache_read_write(): # 3: two level cache_read with compute_at # preparing for GPU's shared memory & local memory - s0 = s0.cache_read(pad_temp, "global", [conv_global], dag) - pad_temp_global = pad_temp + 1 + s0, pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global], dag) kernel_data += 1 kernel_split += 1 kernel += 1 @@ -262,8 +258,8 @@ def test_state_cache_read_write(): conv += 1 relu += 1 add += 1 - s0 = s0.cache_read(pad_temp_global, "shared", [conv_global], dag) - pad_temp_shared = pad_temp_global + 1 + s0, pad_temp_shared = s0.cache_read( + pad_temp_global, "shared", [conv_global], dag) kernel_data += 1 kernel_split += 1 kernel += 1 @@ -279,7 +275,7 @@ def test_state_cache_read_write(): # 4: cache_read with multi readers # This stage cannot be compute at to its consumer - s0 = s0.cache_read(data, "global", [pad_temp, add], dag) + s0, data_global = s0.cache_read(data, "global", [pad_temp, add], dag) pad_temp += 1 pad_temp_global += 1 pad_temp_shared += 1 @@ -350,7 +346,7 @@ def test_state_cache_read_write(): # 5: cache_write with multi outputs # See tests/cpp/ansor_test.cc for more information - s0 = s0.cache_write(kernel_split, "global", dag) + s0, _ = s0.cache_write(kernel_split, "global", dag) assert str(s0) == \ "Placeholder: Data, Kernel_data\n" + \ "for ax0 (0,4)\n" + \ @@ -424,40 +420,39 @@ def test_follow_split_follow_fused_split(): s0 = dag.get_init_state() C = 2 - s0 = s0.cache_write(C, "global", dag) - C_global = C + s0, C_global = s0.cache_write(C, "global", dag) C += 1 - s0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True) + s0, its0 = s0.split(C, s0.stage(C).iterator(0), [4, 2, 8, 4], True) split_step0 = s0.transform_steps_size() - 1 for level in range(1, 6): tmp = s0 - tmp = tmp.follow_split(C_global, tmp.stage( + tmp, _ = tmp.follow_split(C_global, tmp.stage( C_global).iterator(0), split_step0, level) for i in range(0, level): assert tmp.stage(C).iterator(i).range.extent == \ tmp.stage(C_global).iterator(i).range.extent - s0 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 4, 8]) + s0, its1 = s0.split(C, s0.stage(C).iterator(5), [2, 2, 
4, 8]) split_step1 = s0.transform_steps_size() - 1 - its = s0.stage(C).iterators() - s0 = s0.reorder(C, [its[0], its[5], its[1], its[6], its[2], its[7], - its[3], its[8], its[4], its[9]]) - s0 = s0.fuse(C, [s0.stage(C).iterator(0), s0.stage(C).iterator(1)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(1), s0.stage(C).iterator(2)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(2), s0.stage(C).iterator(3)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(3), s0.stage(C).iterator(4)]) - s0 = s0.fuse(C, [s0.stage(C).iterator(4), s0.stage(C).iterator(5)]) + its = [] + for i0, i1 in zip(its0, its1): + its.append(i0) + its.append(i1) + s0 = s0.reorder(C, its) + for i in range(0, 5): + s0, _ = s0.fuse(C, [s0.stage(C).iterator(i), + s0.stage(C).iterator(i+1)]) for level in range(0, 4): tmp = s0 - tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), - [split_step0, split_step1], level, False) + tmp, _ = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), + [split_step0, split_step1], level, False) assert tmp.stage(C).iterator(level+1).range.extent == \ tmp.stage(C_global).iterator(0).range.extent for level in range(0, 4): tmp = s0 - tmp = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), - [split_step0, split_step1], level, True) + tmp, _ = tmp.follow_fused_split(C_global, tmp.stage(C_global).iterator(0), + [split_step0, split_step1], level, True) assert tmp.stage(C).iterator(level+1).range.extent == \ tmp.stage(C_global).iterator(1).range.extent @@ -466,6 +461,49 @@ def test_rfactor(): pass +def test_measure_local_builder_runner(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + + s0 = dag.get_init_state() + A, B, C = 0, 1, 2 + s0, C_global = s0.cache_write(C, "global", dag) + C += 1 + s0, its0 = s0.split(C, s0.stage(C).iterator(0), [4, 8, 8]) + s0, its1 = s0.split(C, s0.stage(C).iterator(4), [8, 4, 4]) + s0 = s0.reorder(C, [its0[0], its1[0], its0[1], its1[1], its0[2], its1[2], + its0[3], its1[3]]) + s0 = s0.compute_at(C_global, C, s0.stage(C).iterator(3)) + s0, _ = s0.split(C_global, s0.stage(C_global).iterator(2), [16]) + s0, B_global = s0.cache_read(B, "global", [C_global], dag) + C += 1 + C_global += 1 + s0 = s0.compute_at(B_global, C_global, s0.stage(C_global).iterator(0)) + s0, A_global = s0.cache_read(A, "global", [C_global], dag) + B += 1 + B_global += 1 + C += 1 + C_global += 1 + s0 = s0.compute_at(A_global, C_global, s0.stage(C_global).iterator(2)) + + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + minp = ansor.MeasureInput(task, s0) + local_builder = ansor.LocalBuilder() + local_runner = ansor.LocalRunner() + + bress = local_builder.build([minp]) + assert bress[0].error_no == 0 + mress = local_runner.run([minp], bress) + assert mress[0].error_no == 0 + + +def test_search_basic(): + dag = ansor.ComputeDAG(matmul_nkkm(512, 512, 512)) + tgt = tvm.target.create("llvm") + task = ansor.SearchTask(dag, "test", tgt) + + if __name__ == "__main__": test_compute_dag_basic() test_state_split_fuse_reorder() @@ -473,3 +511,5 @@ def test_rfactor(): test_state_cache_read_write() test_follow_split_follow_fused_split() test_rfactor() + test_measure_local_builder_runner() + # test_search_basic()
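Usage note: with this patch, the Python state-transforming calls (split, fuse, follow_split, follow_fused_split, cache_read, cache_write, ...) return a pair of (new state, produced result) instead of only the new state, and program measurement can now be driven from Python through MeasureInput, LocalBuilder and LocalRunner. The snippet below is a minimal usage sketch that mirrors test_measure_local_builder_runner above; the matmul helper, the workload key, and the concrete tile sizes are illustrative assumptions and not part of this patch.

import tvm
from tvm import te, ansor

def matmul(n):
    # Illustrative workload only; the unit tests use their own matmul_nkkm helper.
    A = te.placeholder((n, n), name="A")
    B = te.placeholder((n, n), name="B")
    k = te.reduce_axis((0, n), name="k")
    C = te.compute((n, n), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
    return [A, B, C]

dag = ansor.ComputeDAG(matmul(512))
state = dag.get_init_state()

# Transform calls now return (new_state, result); `its` holds the two iterators
# created by the split, so they no longer have to be re-read from the stage.
state, its = state.split(2, state.stage(2).iterator(0), [16])
state, fused = state.fuse(2, [its[0], its[1]])

tgt = tvm.target.create("llvm")
task = ansor.SearchTask(dag, "test", tgt)

# Build the program locally, then run it to measure its running time.
inp = ansor.MeasureInput(task, state)
build_results = ansor.LocalBuilder().build([inp])
measure_results = ansor.LocalRunner().run([inp], build_results)
assert build_results[0].error_no == 0
assert measure_results[0].error_no == 0

Returning the produced iterators (or the new stage id for cache_read/cache_write) alongside the new state is what lets the updated tests drop the repeated stage(...).iterator(...) lookups after every transform step.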