diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index ff93704edb442..f2819d1dad369 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -207,7 +207,18 @@ def set_task(self, task): for x in arg_bufs] func = build(s, arg_bufs, "llvm") tvm_buf = [nd.array(x) for x in self.ref_input] - func(*tvm_buf) + + def _run_func(): + """Run tvm function in a thread. + Because there is some issues with python multiprocessing and the thread pool in tvm + """ + func(*tvm_buf) + + thread = threading.Thread(target=_run_func) + thread.start() + thread.join() + del thread + self.ref_output = [x.asnumpy() for x in tvm_buf] def get_build_kwargs(self): diff --git a/python/tvm/autotvm/task/nnvm_integration.py b/python/tvm/autotvm/task/nnvm_integration.py index 6a07194a594d5..4d1551c57209f 100644 --- a/python/tvm/autotvm/task/nnvm_integration.py +++ b/python/tvm/autotvm/task/nnvm_integration.py @@ -5,6 +5,7 @@ """ import warnings import logging +import sys from ... import tensor, placeholder, create_schedule, target as _target @@ -49,9 +50,9 @@ def deserialize_args(args): # Task extractor for nnvm graph class TaskExtractEnv: """Global environment for extracting tuning tasks from nnvm graph""" - current = None + registered = False - def __init__(self): + def __init__(self, wanted_symbols): import topi import nnvm @@ -83,46 +84,62 @@ def __init__(self): topi.nn.dense: [topi.generic.schedule_dense], } - self._register_tracing() + # support reflection for tracing + self.func_to_reflection = { + topi.nn.conv2d: lambda x: setattr(topi.nn, 'conv2d', x), + topi.nn.depthwise_conv2d_nchw: lambda x: setattr(topi.nn, 'depthwise_conv2d_nchw', x), + topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x), + topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x), + } + + + self.wanted_topi_funcs = [] + for sym_name in wanted_symbols: + if sym_name in self.symbol2topi: + self.wanted_topi_funcs.extend(self.symbol2topi[sym_name]) + else: + warnings.warn("Symbol %s is not tunable, ignored" % sym_name) + self._register_topi_task() self.task_collection = [] - self.wanted_topi_funcs = list(self.topi_to_task.keys()) + self.modified_funcs = [] - def _register_tracing(self): - """Register tracing function to track the topi function call""" - # register topi compute for "tracing" target - for topi_compute in self.topi_to_task: + def __enter__(self): + self.task_collection = [] + self.modified_funcs = [] + + for topi_compute in self.wanted_topi_funcs: def _local_scope(compute_func): """start a scope to hold the local function in for loop""" - @compute_func.register("tracing", ) - def _tracing_topi_compute(*args, **kwargs): - assert not kwargs, "Do not support extracting tuning tasks when" \ - "kwargs is used in TOPI function call." \ + def _tracing_wrapper(*args, **kwargs): + assert not kwargs, "Do not support extracting tuning tasks when " \ + "kwargs is used in TOPI function call. " \ "Please modify it to use only positional args." - if compute_func in self.wanted_topi_funcs: # record this call - key = (self.topi_to_task[compute_func], serialize_args(args)) - if key not in self.task_collection: - self.task_collection.append(key) + key = (self.topi_to_task[compute_func], serialize_args(args)) + if key not in self.task_collection: + self.task_collection.append(key) + + return compute_func(*args, **kwargs) + + self.func_to_reflection[topi_compute](_tracing_wrapper) + self.modified_funcs.append(topi_compute) - return compute_func.fdefault(*args) _local_scope(topi_compute) - # register topi schedule for "tracing" target - for topi_compute in self.topi_to_task: - for topi_schedule in self.topi_to_schedule[topi_compute]: - def _local_scope_(schedule_func): - """start a scope to hold the local function in for loop""" + return self - @schedule_func.register("tracing", ) - def _tracing_topi_compute(outs): - outs = [outs] if isinstance(outs, tensor.Tensor) else outs - return create_schedule([x.op for x in outs]) - _local_scope_(topi_schedule) + def __exit__(self, exc_type, exc_val, exc_tb): + # revert modification + for func in self.modified_funcs: + self.func_to_reflection[func](func) def _register_topi_task(self): """register tuning wrapper for topi function""" + if TaskExtractEnv.registered: + return + TaskExtractEnv.registered = True import topi # Tuning wrapper for topi functions @@ -175,17 +192,6 @@ def _topi_nn_dense(*args, **kwargs): return s, [data, weight, bias, C] return s, [data, weight, C] - def reset(self, wanted_topi_funcs): - """Reset task collections - - Parameters - ---------- - wanted_topi_funcs: List of function - The topi function to be extracted - """ - self.task_collection = [] - self.wanted_topi_funcs = wanted_topi_funcs - def get_tasks(self): """Get collected tasks @@ -196,25 +202,11 @@ def get_tasks(self): """ return self.task_collection - @staticmethod - def get(): - """Get the single instance of TaskExtractEnv - - Returns - ------- - env: TaskExtractEnv - The single instance of TaskExtractEnv - """ - if not TaskExtractEnv.current: - TaskExtractEnv.current = TaskExtractEnv() - return TaskExtractEnv.current - def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): """ Extract tuning tasks from a nnvm graph. - This function collects tuning tasks by building the graph - with a "tracing" target and tracing all the calls to topi. + This function collects tuning tasks by building the graph and trace all the calls to topi. Parameters ---------- @@ -237,97 +229,34 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None): collected tasks """ import nnvm.compiler + import topi - env = TaskExtractEnv.get() + env = TaskExtractEnv(symbols) - topi_funcs = [] - for sym_name in symbols: - if sym_name in env.symbol2topi: - topi_funcs.extend(env.symbol2topi[sym_name]) - else: - warnings.warn("Symbol %s is not tunable, ignored" % sym_name) + with env: + # disable logger temporarily + old_state = logger.disabled + logger.disabled = True - # run compiler to collect all TOPI calls during compilation - env.reset(topi_funcs) + # run compiler to collect all TOPI calls during compilation + nnvm.compiler.engine.clear_cache() + nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype) + nnvm.compiler.engine.clear_cache() - # disable logger temporarily - old_state = logger.disabled - logger.disabled = True - - # use a "tracing" target to do a fake compile for collecting topi calls - tracing_target = _target.create("llvm -device=tracing") - nnvm.compiler.engine.clear_cache() - nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype) - - logger.disabled = old_state + logger.disabled = old_state # create tasks for target tasks = [] for task_name, args in env.get_tasks(): - tasks.append(create(task_name, args, - target=target, target_host=target_host, - template_key='direct')) + try: + tsk = create(task_name, args, + target=target, target_host=target_host, + template_key='direct') + tasks.append(tsk) + except topi.InvalidShapeError: + print("shape error") return tasks - -def extract_from_multiple_graph(graphs, shapes, dtypes, target, symbols, target_host=None): - """ Extract tuning tasks from multiple nnvm graphs. - - This function is the multiple graph version of extract_from_graph - - Parameters - ---------- - graphs : List of Graph - The list of graphs to tune - shapes : List of dict of str to tuple - The input shape to the graph - dtypes : List of str or dict of str to str - The input types to the graph - target: tvm.target.Target - The compilation target - symbols : Array of nnvm.symbol - Array of nnvm symbols want to be tuned - target_host: tvm.target.Target - The host compilation target - - Returns - ------- - task: Array of autotvm.task.Task - collected tasks - """ - import nnvm.compiler - - env = TaskExtractEnv.get() - - topi_funcs = [] - for sym_name in symbols: - if sym_name in env.symbol2topi: - topi_funcs.extend(env.symbol2topi[sym_name]) - else: - warnings.warn("Symbol %s is not tunable, ignored" % sym_name) - - # run compiler to collect all TOPI calls during compilation - env.reset(topi_funcs) - - # disable logger temporarily - old_state = logger.disabled - logger.disabled = True - - # use a "tracing" target to do a fake compile for collecting topi calls - tracing_target = _target.create("llvm -device=tracing") - - nnvm.compiler.engine.clear_cache() - for graph, shape, dtype in zip(graphs, shapes, dtypes): - nnvm.compiler.build(graph, target=tracing_target, shape=shape, dtype=dtype) - - logger.disabled = old_state - - # create tasks for target - tasks = [] - for task_name, args in env.get_tasks(): - tasks.append(create(task_name, args, - target=target, target_host=target_host, - template_key='direct')) - - return tasks +def extract_from_multiple_graph(graph, shape, dtype, target, symbols, target_host=None): + pass diff --git a/python/tvm/target.py b/python/tvm/target.py index 75f82743f9fae..23bec3adddda7 100644 --- a/python/tvm/target.py +++ b/python/tvm/target.py @@ -473,6 +473,13 @@ def rasp(options=None): return arm_cpu('rasp3b', options) +def vta(model='unknown', options=None): + opts = ["-device=vta", '-keys=cpu', '-model=%s' % model] + opts = _merge_opts(opts, options) + ret = _api_internal._TargetCreate("ext_dev", *opts) + return ret + + def create(target_str): """Get a target given target string. diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 0659a07f2520a..f20c0c5eeb195 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -39,6 +39,7 @@ Target CreateTarget(const std::string& target_name, std::string libs_flag = "-libs="; std::string device_flag = "-device="; + std::string keys_flag = "-keys="; for (auto& item : options) { t->options_array.push_back(ir::StringImm::make(item)); @@ -50,12 +51,16 @@ Target CreateTarget(const std::string& target_name, } } else if (item.find(device_flag) == 0) { t->device_name = item.substr(device_flag.length()); + t->keys_array.push_back(ir::StringImm::make(t->device_name)); + } else if (item.find(keys_flag) == 0) { + std::stringstream ss(item.substr(keys_flag.length())); + std::string key_item; + while (std::getline(ss, key_item, ',')) { + t->keys_array.push_back(ir::StringImm::make(key_item)); + } } } - if (t->device_name.length() > 0) { - t->keys_array.push_back(ir::StringImm::make(t->device_name)); - } t->device_type = kDLCPU; t->thread_warp_size = 1; if (target_name == "llvm") { diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py index 2eb460d151ae5..ed03da110c1f3 100644 --- a/topi/python/topi/__init__.py +++ b/topi/python/topi/__init__.py @@ -34,6 +34,10 @@ from . import image from . import sparse from . import hls + +# some short cut +from .util import InvalidShapeError + # not import testing by default # because testing can have extra deps that are not necessary # we can import them from test cases explicitly diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index de9ff90ae26ba..6b58c9993cfe0 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -6,6 +6,10 @@ import tvm from . import tag +class InvalidShapeError(ValueError): + """Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)""" + pass + def traverse_inline(s, final_op, callback): """Traverse computation graph and do auto inline diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py index 18f501f4ebb1a..e765fa39e18bb 100644 --- a/vta/python/vta/__init__.py +++ b/vta/python/vta/__init__.py @@ -18,6 +18,7 @@ # to maintain minimum dependency on the board if sys.argv[0] not in ("-c", "-m"): from . import top - from .build_module import build_config, lower, build from . import graph + + from .build_module import build_config, lower, build, vta_autotvm_build_func from .ptr_alias import reinterpret diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py index a1d2299ba7aa6..0d6ee7fcf6a56 100644 --- a/vta/python/vta/build_module.py +++ b/vta/python/vta/build_module.py @@ -102,3 +102,39 @@ def build(*args, **kwargs): with build_config(): return tvm.build(*args, **kwargs) return tvm.build(*args, **kwargs) + + +def vta_autotvm_build_func(measure_input, tmp_dir, **kwargs): + """Custom build func for VTA. Used for autotvm""" + + import time + import os + from random import getrandbits + from tvm.autotvm.util import get_const_tuple + from tvm.autotvm.measure.measure_methods import BuildResult, InstantiationError + + tic = time.time() + try: + filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64)) + target, task, config = measure_input + + with target: + s, args = task.instantiate(config) + if not config.valid(): + raise InstantiationError(config.errors) + + func = build(s, args, target_host=task.target_host) + func2 = build(s, args) + + arg_info = tuple((get_const_tuple(x.shape), x.dtype) for x in args) + func.export_library(filename) + + # check by local simulator + ctx = tvm.context(str(target)) + args = [tvm.nd.empty(x[0], dtype=x[1], ctx=ctx) for x in arg_info] + func2(*args) + + except Exception as e: # pylint: disable=broad-except + return BuildResult(None, None, e, time.time() - tic) + return BuildResult(filename, arg_info, None, time.time() - tic) + diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index d8590ce74d31c..362adf69686b8 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -227,35 +227,25 @@ def gemm(self): """GEMM intrinsic""" return self.dev.gemm - # TODO get rid of it @property - def target_host(self): - """The target host""" - return "llvm " + self.llvm_triple + def target(self): + return tvm.target.vta(model=self.TARGET) @property - def target_vta_cpu(self): + def target_host(self): """The target host""" if self.TARGET == "pynq": - return "llvm -device=arm_cpu -model=pynq {}".format(self.llvm_triple) + return "llvm -target=armv7-none-linux-gnueabihf" elif self.TARGET == "ultra96": - return "llvm -device=arm_cpu -model=ultra96 {}".format(self.llvm_triple) + return "llvm -target=aarch64-linux-gnu" elif self.TARGET == "sim": return "llvm" else: raise ValueError("Unknown target %s" % self.TARGET) @property - def llvm_triple(self): - """The llvm flags for the target platform""" - if self.TARGET == "pynq": - return "-target=armv7-none-linux-gnueabihf" - elif self.TARGET == "ultra96": - return "-target=aarch64-linux-gnu" - elif self.TARGET == "sim": - return "" - else: - raise ValueError("Unknown target %s" % self.TARGET) + def target_vta_cpu(self): + return tvm.target.arm_cpu(model=self.TARGET) def get_env(): diff --git a/vta/python/vta/top/__init__.py b/vta/python/vta/top/__init__.py index 3e2563285434d..ac1e7c60ebdc6 100644 --- a/vta/python/vta/top/__init__.py +++ b/vta/python/vta/top/__init__.py @@ -1,11 +1,4 @@ """TVM TOPI connector, eventually most of these should go to TVM repo""" from . import vta_conv2d -from . import arm_conv2d -from . import testing - -from .bitpack import bitpack -from .vta_dense import packed_dense, schedule_packed_dense -from .vta_conv2d import packed_conv2d, schedule_packed_conv2d -from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d -from .vta_conv2d_transpose import packed_conv2d_transpose, schedule_packed_conv2d_transpose +from . import op diff --git a/vta/python/vta/top/arm_conv2d.py b/vta/python/vta/top/arm_conv2d.py deleted file mode 100644 index 506a47a4c3443..0000000000000 --- a/vta/python/vta/top/arm_conv2d.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Reuse conv2d schedule from ARM CPU""" - -import tvm - -from topi.nn import conv2d, conv2d_alter_layout -from topi import generic - -@conv2d.register(["vta"]) -def compute(*args, **kwargs): - target = tvm.target.current_target() - with tvm.target.arm_cpu(model=target.model): - return conv2d(*args, **kwargs) - -@generic.schedule_conv2d_nchw.register(["vta"]) -def schedule(*args, **kwargs): - target = tvm.target.current_target() - with tvm.target.arm_cpu(model=target.model): - return generic.schedule_conv2d_nchw(*args, **kwargs) - -@conv2d_alter_layout.register(["vta"]) -def alter(*args, **kwargs): - target = tvm.target.current_target() - with tvm.target.arm_cpu(model=target.model): - return conv2d_alter_layout(*args, **kwargs) diff --git a/vta/python/vta/top/op.py b/vta/python/vta/top/op.py new file mode 100644 index 0000000000000..34897ac837801 --- /dev/null +++ b/vta/python/vta/top/op.py @@ -0,0 +1,153 @@ +"""Namespace for supporting packed_conv2d + ewise variant of nnvm.""" +from __future__ import absolute_import as _abs + +from collections import namedtuple + +import logging + +import tvm +from tvm import autotvm +import topi + +from nnvm.top import registry as reg, OpPattern +from nnvm.top import nn as _nn + +from ..environment import get_env +from ..ptr_alias import reinterpret + +from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d +from .vta_conv2d_transpose import packed_conv2d_transpose, schedule_packed_conv2d_transpose + +@reg.register_compute("clip", level=15) +def compute_clip(attrs, inputs, _): + """ Clip operator. """ + x = inputs[0] + a_min = attrs.get_float("a_min") + a_max = attrs.get_float("a_max") + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + with tvm.tag_scope(topi.tag.ELEMWISE): + x = tvm.compute( + x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute( + x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + +# override to force partition at copy +reg.register_pattern("copy", OpPattern.INJECTIVE, level=15) + +def is_packed_layout(layout): + """Check if layout is packed layout""" + if layout == "NCHW": + return False + if "n" in layout and "c" in layout: + return True + return False + +@reg.register_compute("conv2d", level=15) +def compute_conv2d(attrs, inputs, out): + """ 2D convolution algorithm. + """ + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + groups = attrs.get_int("groups") + layout = attrs["layout"] + out_dtype = attrs['out_dtype'] + + assert dilation == (1, 1), "not support dilate now" + if is_packed_layout(layout): + if groups == 1: + assert groups == 1 + env = get_env() + assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now" + assert env.LOG_OUT_WIDTH == 3, "only support 8bit inp for now" + inputs = list(inputs) + w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH) + assert inputs[1].dtype == "int8" + + # Apply bit packing if necessary + if w_pack_factor != 1: + kshape = list(topi.util.get_const_tuple(inputs[1].shape)) + kshape[-1] *= w_pack_factor + inputs[1] = reinterpret(inputs[1], kshape, dtype=env.wgt_dtype) + + return topi.nn.conv2d(inputs[0], inputs[1], strides, padding, layout, out_dtype) + else: + return packed_group_conv2d(inputs[0], inputs[1], padding, strides, groups, out_dtype) + + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.compute_conv2d(attrs, inputs, out) + + +@reg.register_schedule("conv2d", level=15) +def schedule_conv2d(attrs, outs, target): + """ 2D convolution schedule. + """ + layout = attrs["layout"] + groups = attrs.get_int('groups') + + if is_packed_layout(layout): + target = tvm.target.create(target) + if target.device_name == "vta": + if groups == 1: + return topi.generic.schedule_conv2d_nchw(outs) + else: + return schedule_packed_group_conv2d(outs) + elif str(target).startswith("llvm"): + return tvm.create_schedule([x.op for x in outs]) + else: + raise RuntimeError("not support target %s" % target) + + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.schedule_conv2d(attrs, outs, tvm.target.current_target()) + + +@reg.register_alter_op_layout("conv2d", level=15) +def alter_conv2d_layout(attrs, inputs, out): + layout = attrs['layout'] + if is_packed_layout(layout): + return None + + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.alter_conv2d_layout(attrs, inputs, out) + + +@reg.register_compute("conv2d_transpose", level=15) +def compute_conv2d_transpose(attrs, inputs, out): + """ 2D convolution algorithm. + """ + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + layout = attrs["layout"] + out_dtype = attrs['out_dtype'] + + assert dilation == (1, 1), "not support dilate now" + if is_packed_layout(layout): + return packed_conv2d_transpose(inputs[0], inputs[1], + padding, strides, + out_dtype=out_dtype) + + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.compute_conv2d_transpose(attrs, inputs, out) + + +@reg.register_schedule("conv2d_transpose", level=15) +def schedule_conv2d_transpose(attrs, outs, target): + """ 2D convolution schedule. + """ + layout = attrs["layout"] + + if is_packed_layout(layout): + target = tvm.target.create(target) + if target.device_name == "vta": + return schedule_packed_conv2d_transpose(outs) + elif str(target).startswith("llvm"): + return tvm.create_schedule([x.op for x in outs]) + else: + raise RuntimeError("not support target %s" % target) + + with tvm.target.arm_cpu(tvm.target.current_target().model): + return _nn.schedule_conv2d_transpose(attrs, outs, target) + diff --git a/vta/python/vta/top/testing/__init__.py b/vta/python/vta/top/testing/__init__.py deleted file mode 100644 index 0ba1567d21d25..0000000000000 --- a/vta/python/vta/top/testing/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""TVM TOPI test, used by integration tests""" - -from .vta_conv2d_test import run_vta_conv2d, my_clip \ No newline at end of file diff --git a/vta/python/vta/top/testing/vta_conv2d_test.py b/vta/python/vta/top/testing/vta_conv2d_test.py deleted file mode 100644 index 5116d1f260238..0000000000000 --- a/vta/python/vta/top/testing/vta_conv2d_test.py +++ /dev/null @@ -1,220 +0,0 @@ -import tvm -from tvm import autotvm -from tvm.contrib import util -from tvm.contrib.pickle_memoize import memoize -import topi -import topi.testing -import vta -import vta.testing -from vta.testing import simulator -import numpy as np - -import os -import json - -def _sign_extend(value, bits): - sign_bit = 1 << (bits - 1) - return (value & (sign_bit - 1)) - (value & sign_bit) - -_vector_sign_extend = np.vectorize(_sign_extend) - -def _pack(x, width): - assert(len(x.shape)==6) - assert(x.dtype=="int8") - pack_factor = 8 // width - mask = ((1 << width) - 1) - - s = x.shape - s_reshape = s[:-1] + (s[-1] // pack_factor, pack_factor) - s_pack = s[:-1] + (s[-1] // pack_factor,) - x_reshape = x.reshape(s_reshape) - x_packed = np.zeros(s_pack, dtype="int8") - for i in range(0, pack_factor): - x_packed |= (x_reshape[:,:,:,:,:,:,i] & mask) << (i * width) - - return x_packed - -def _unpack(x, width): - assert(len(x.shape)==6) - assert(x.dtype=="int8") - pack_factor = 8 // width - mask = ((1 << width) - 1) - - s = x.shape - x_unpack = np.zeros(s[:] + (pack_factor,), dtype=x.dtype) - for i in range(0, pack_factor): - x_unpack[:,:,:,:,:,:,i] = _vector_sign_extend(((x >> (i * width)) & mask), width) - - return x_unpack.reshape(s[:-1] + (s[-1] * pack_factor,)) - -@tvm.tag_scope(tag=topi.tag.ELEMWISE) -def my_clip(x, a_min, a_max): - """Unlike topi's current clip, put min and max into two stages.""" - const_min = tvm.const(a_min, x.dtype) - const_max = tvm.const(a_max, x.dtype) - x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") - x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") - return x - -def run_vta_conv2d(env, remote, wl, print_ir=False, - plan_str=None, samples=5, profileOnly=False, - skip_load_inp=False, skip_load_wgt=False, skip_load_acc=False, - skip_store_out=False, skip_alu=False, skip_gemm=False): - data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN, - wl.height, wl.width, env.BATCH, env.BLOCK_IN) - kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN, - wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN) - bias_shape = (wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT, - 1, 1, env.BATCH, env.BLOCK_OUT) - data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) - bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype) - - # Handle quantized inputs (less than 8 bits) - # x_pack_factor = 1 << (3 - env.LOG_INP_WIDTH) - # data_shape_pack = data_shape[:-1] + (data_shape[-1]//x_pack_factor,) - # data_arg = tvm.placeholder( - # data_shape_pack, - # dtype="int8", name="data_arg") - # data = vta.reinterpret(data_arg, data_shape, dtype=env.inp_dtype) - - # Handle quantized kernels (less than 8 bits) - w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH) - kernel_shape_pack = kernel_shape[:-1] + (kernel_shape[-1]//w_pack_factor,) - kernel_arg = tvm.placeholder( - kernel_shape_pack, - dtype="int8", name="kernel_arg") - kernel = vta.reinterpret(kernel_arg, kernel_shape, dtype=env.wgt_dtype) - - res_conv = vta.top.packed_conv2d( - data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride)) - res = topi.right_shift(res_conv, 8) - res = topi.add(res, bias) - res = my_clip(res, 0, (1 << env.OUT_WIDTH-1)-1) - res = topi.cast(res, "int8") - - # Handle quantized outputs (less than 8 bits) - # o_pack_factor = 1 << (3 - env.LOG_OUT_WIDTH) - res_shape = topi.util.get_const_tuple(res.shape) - # res_shape_pack = res_shape[:-1] + (res_shape[-1]//o_pack_factor,) - # res_arg = vta.reinterpret(res, res_shape_pack, dtype="int8") - - # To compute number of ops, use a x2 factor for FMA - fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 - fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 - - a_shape = (wl.batch, wl.in_filter, wl.height, wl.width) - w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) - stride = (wl.hstride, wl.wstride) - data_dtype = data.dtype - kernel_dtype = kernel.dtype - acc_dtype = env.acc_dtype - assert wl.hpad == wl.wpad - padding = wl.hpad - - # @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc") - def get_ref_data(): - # derive min max for input and weight types (max non inclusive) - a_min, a_max = 0 - (1 << (env.INP_WIDTH - 1)), (1 << (env.INP_WIDTH - 1)) - w_min, w_max = 0 - (1 << (env.WGT_WIDTH - 1)), (1 << (env.WGT_WIDTH - 1)) - a_np = np.random.randint( - a_min, a_max, size=a_shape).astype("int8") - w_np = np.random.randint( - w_min, w_max, size=w_shape).astype("int8") - b_np = topi.testing.conv2d_nchw_python( - a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype) - return a_np, w_np, b_np - - def verify(s, check_correctness): - mod = vta.build(s, - [data, kernel_arg, bias, res], - "ext_dev", - env.target_host, name="conv2d") - temp = util.tempdir() - - mod.save(temp.relpath("conv2d.o")) - remote.upload(temp.relpath("conv2d.o")) - f = remote.load_module("conv2d.o") - # verify - ctx = remote.ext_dev(0) - # Data in original format - data_orig, kernel_orig, res_ref = get_ref_data() - bias_orig = (np.random.uniform(size=(wl.batch, wl.out_filter,)) * (1 << (env.INP_WIDTH + env.WGT_WIDTH - 2))) - bias_orig = bias_orig.astype("int32") - bias_orig = np.abs(bias_orig) - - data_packed = data_orig.reshape( - wl.batch//env.BATCH, env.BATCH, - wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, - wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3)) - kernel_packed = kernel_orig.reshape( - wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT, - wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, - wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) - bias_packed = bias_orig.reshape( - wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT, - 1, 1, env.BATCH, env.BLOCK_OUT) - - # Quantized packing - data_qpacked = _pack(data_packed, env.INP_WIDTH) - kernel_qpacked = _pack(kernel_packed, env.WGT_WIDTH) - - res_np = np.zeros(res_shape).astype(res.dtype) - data_arr = tvm.nd.array(data_qpacked, ctx) - kernel_arr = tvm.nd.array(kernel_qpacked, ctx) - bias_arr = tvm.nd.array(bias_packed, ctx) - res_arr = tvm.nd.array(res_np, ctx) - time_f = f.time_evaluator("conv2d", ctx, number=1, repeat=samples) - - # In sim mode, collect simulator runtime statistics - stats = {} - cost = None - if env.TARGET == "sim": - # Check if we're in local RPC mode (allows us to rebuild the - # runtime on the fly when varying the VTA designs) - local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0")) - if local_rpc: - remote.get_function("vta.simulator.profiler_clear")() - if profileOnly: - remote.get_function("vta.simulator.profiler_debug_mode")(1) - cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) - stats = json.loads(remote.get_function("vta.simulator.profiler_status")()) - else: - simulator.clear_stats() - if profileOnly: - simulator.debug_mode(1) - cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) - stats = simulator.stats() - else: - cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) - - # Check correctness - correct = False - if check_correctness: - res_unpack = res_arr.asnumpy() - res_unpack = _unpack(res_unpack.astype("int8"), env.OUT_WIDTH) - res_unpack = res_unpack.transpose( - (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width) - assert wl.hpad == wl.wpad - stride = (wl.hstride, wl.wstride) - padding = wl.hpad - res_ref = res_ref >> 8 - res_ref += bias_orig.reshape(wl.out_filter, 1, 1) - res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH-1)-1) - res_ref = res_ref.astype("int8") - correct = np.allclose(res_unpack, res_ref) - - return (correct, cost, stats) - - with vta.build_config(): - s = vta.top.schedule_packed_conv2d([res], - planStr=plan_str, - skip_load_inp=skip_load_inp, - skip_load_wgt=skip_load_wgt, - skip_load_acc=skip_load_acc, - skip_store_out=skip_store_out, - skip_alu=skip_alu, - skip_gemm=skip_gemm) - if print_ir: - print(vta.lower(s, [data, kernel_arg, bias, res], simple_mode=True)) - return verify(s, profileOnly is False) - diff --git a/vta/python/vta/top/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py index 767586a998124..3e517f98e9c7a 100644 --- a/vta/python/vta/top/vta_conv2d.py +++ b/vta/python/vta/top/vta_conv2d.py @@ -1,182 +1,29 @@ """Namespace for supporting packed_conv2d + ewise variant of nnvm.""" -from __future__ import absolute_import as _abs -from collections import namedtuple - -import logging import tvm +from tvm import autotvm import topi -import re - -from nnvm.top import registry as reg, OpPattern -from nnvm.top import nn as _nn -from ..environment import get_env -from ..ptr_alias import reinterpret -from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d -from .vta_conv2d_transpose import packed_conv2d_transpose, schedule_packed_conv2d_transpose - -Workload = namedtuple("Conv2DWorkload", - ['batch', 'height', 'width', 'in_filter', 'out_filter', - 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) - -_SCHEDULE_STR_MAP = {} - - -def find_schedules(layer, vt_only=False, best_only=False): - """ Returns a schedule for a given a layer. - - Parameters - ---------- - layer : Workload - Convolutional layer description. - vt_only : Boolean - Produce a schedule plan with virtual threading. - best_only : Boolean - Return the "best" schedule plan. - - Returns - ------- - fil_sched : list - List of valid schedules. - - """ - # pylint: disable=too-many-nested-blocks - env = get_env() - - # Helper function to get factors - def _find_factors(n): - factors = [] - for f in range(1, n + 1): - if n % f == 0: - factors.append(f) - return factors - - def _get_data_movement_byte(schedule, layer): - """ Estimate data movement in bytes for the schedule plan - """ - env = get_env() - b_f = schedule.b_factor - h_f = schedule.h_factor - w_f = schedule.w_factor - ci_f = schedule.ic_factor - co_f = schedule.oc_factor - # Derive data movement - inp_elem_sizeb = env.BATCH * env.BLOCK_IN * env.INP_WIDTH - wgt_elem_sizeb = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH - out_elem_sizeb = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH - input_tile_elems = b_f * \ - ((h_f - 1) * layer.hstride + layer.hkernel) * \ - ((w_f - 1) * layer.wstride + layer.wkernel) * ci_f - weight_tile_elems = layer.hkernel * layer.wkernel * ci_f - output_tile_elems = b_f * h_f * w_f * co_f - # Derive tiling factors - b_factor = layer.batch // (b_f * env.BATCH) - h_factor = (layer.height // layer.hstride) // h_f - w_factor = (layer.width // layer.wstride) // w_f - ci_factor = layer.in_filter // (ci_f * env.BLOCK_IN) - co_factor = layer.out_filter // (co_f * env.BLOCK_OUT) - # Compute input transaction count - input_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor - weight_xfers = b_factor * h_factor * w_factor * co_factor * ci_factor - output_xfers = b_factor * h_factor * w_factor * co_factor - # Compute total transfer sizes - input_xfer_byte = input_tile_elems * input_xfers * inp_elem_sizeb // 8 - weight_xfer_byte = weight_tile_elems * weight_xfers * wgt_elem_sizeb // 8 - output_xfer_byte = output_tile_elems * output_xfers * out_elem_sizeb // 8 - total_xfer_byte = input_xfer_byte + weight_xfer_byte + output_xfer_byte - return total_xfer_byte - - # Scheduling exploration - OH = (layer.height + 2 * layer.hpad - layer.hkernel) // layer.hstride + 1 - OW = (layer.width + 2 * layer.wpad - layer.wkernel) // layer.wstride + 1 - batch_factors = _find_factors(layer.batch // env.BATCH) - height_factors = _find_factors(OH) - width_factors = _find_factors(OW) - cin_factors = _find_factors(layer.in_filter // env.BLOCK_IN) - cout_factors = _find_factors(layer.out_filter // env.BLOCK_OUT) - ht_factors = [1, 2] - cot_factors = [1, 2] - - # Explore schedules - schedules = [] - for b_f in batch_factors: - for h_f in height_factors: - for w_f in width_factors: - for ci_f in cin_factors: - for co_f in cout_factors: - # FIXME: 2D load pattern matching imposes restrictions on schedule - valid = (w_f == layer.width // layer.wstride) or \ - (w_f != layer.width // layer.wstride and co_f == 1) and \ - ci_f == 1 - if valid: - schedules.append([b_f, h_f, w_f, ci_f, co_f]) - # Filter the schedules that wouldn't work in the available BRAM sizes - inp_elem_sizeb = env.BATCH * env.BLOCK_IN * env.INP_WIDTH - wgt_elem_sizeb = env.BLOCK_IN * env.BLOCK_OUT * env.WGT_WIDTH - out_elem_sizeb = env.BATCH * env.BLOCK_OUT * env.OUT_WIDTH - inp_brams_sizeb = env.INP_BUFF_SIZE * 8 - wgt_brams_sizeb = env.WGT_BUFF_SIZE * 8 - out_brams_sizeb = env.OUT_BUFF_SIZE * 8 - fil_sched = [] - xfer_size = [] - for sched in schedules: - b_f, h_f, w_f, ci_f, co_f = sched - for h_t in ht_factors: - for co_t in cot_factors: - # Make sure to filter cases where we apply threading on two axes - # or cases where the threading factors for h and co are not - # factors of h and co - if (h_t == 2 and co_t == 2) or (h_f % h_t != 0) or (co_f % co_t != 0): - continue - # Adjust tile sizes if threading is applied - h_f //= h_t - co_f //= co_t - # Derive tile sizes - input_tile_elems = b_f * \ - ((h_f - 1) * layer.hstride + layer.hkernel) * \ - ((w_f - 1) * layer.wstride + layer.wkernel) * ci_f - weight_tile_elems = layer.hkernel * layer.wkernel * ci_f * co_f - output_tile_elems = b_f * h_f * w_f * co_f +import numpy as np - # Derive valid schedule filter - valid = True - # If in vitrual-threaded mode, only allow for threaded plans - valid &= (vt_only and (h_t == 2 or co_t == 2)) or not vt_only - # Check that we don't exceed input/weight/output capacity - valid &= input_tile_elems * inp_elem_sizeb <= inp_brams_sizeb // (co_t * h_t) - valid &= weight_tile_elems * wgt_elem_sizeb <= wgt_brams_sizeb // (co_t * h_t) - valid &= output_tile_elems * out_elem_sizeb <= out_brams_sizeb // (co_t * h_t) - # Make sure that we don't write to the same acc location within 2 consecutive cycles - valid &= h_f > 2 and w_f > 2 - # TODO: check that we don't exceed instruction or micro-op count - - if valid: - schedule = Schedule(b_factor=b_f, oc_factor=co_f, ic_factor=ci_f, h_factor=h_f, - w_factor=w_f, oc_nthread=co_t, h_nthread=h_t) - fil_sched.append(schedule) - xfer_size.append(_get_data_movement_byte(schedule, layer)) +from ..environment import get_env +from .op import is_packed_layout - if best_only: - return [fil_sched[xfer_size.index(min(xfer_size))]] - return fil_sched +@autotvm.register_topi_compute(topi.nn.conv2d, 'vta', 'direct') +def packed_conv2d(cfg, data, kernel, strides, padding, layout, out_dtype): + """ Packed conv2d function.""" + if not is_packed_layout(layout): + raise topi.InvalidShapeError() -def packed_conv2d(data, - kernel, - padding, - strides, - out_dtype="int32"): - """ Packed conv2d function. - """ if padding[0]: pad_data = topi.nn.pad(data, [0, 0, padding[0], padding[1], 0, 0], name="pad_data") else: pad_data = data assert len(data.shape) == 6 assert len(kernel.shape) == 6 - oheight = topi.util.simplify((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1) - owidth = topi.util.simplify((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1) + oheight = topi.util.get_const_int((pad_data.shape[2] - kernel.shape[2]) // strides[0] + 1) + owidth = topi.util.get_const_int((pad_data.shape[3] - kernel.shape[3]) // strides[1] + 1) oshape = (data.shape[0], kernel.shape[0], oheight, owidth, data.shape[4], kernel.shape[4]) ishape = topi.util.get_const_tuple(data.shape) @@ -193,200 +40,16 @@ def packed_conv2d(data, kernel[c_o, k_o, d_i, d_j, c_i, k_i].astype(out_dtype), axis=[k_o, d_i, d_j, k_i]), name="res", tag="packed_conv2d") - return res - -@tvm.register_func("nnvm.compiler.build_target", override=True) -def _build(funcs, target, target_host): - tvm_t = tvm.target.create(target) - if tvm_t.device_name == "vta": - return tvm.build(funcs, target="ext_dev", target_host=target_host) - elif tvm_t.device_name == "rasp" or tvm_t.device_name == "vtacpu": - return tvm.build(funcs, target=target_host) - return tvm.build(funcs, target=target) - - -@tvm.register_func("nnvm.compiler.lower", override=True) -def _lower(sch, inputs, func_name, graph): - import traceback - # pylint: disable=broad-except - try: - f = tvm.lower(sch, inputs, name=func_name) - if "quantized_conv2d" in func_name: - logging.info(graph.ir(join_entry_attrs=["shape"])) - except Exception: - msg = traceback.format_exc() - msg += "Error during compile graph\n" - msg += "--------------------------\n" - msg += graph.ir(join_entry_attrs=["shape"]) - raise RuntimeError(msg) - return f if isinstance( - f, (tvm.container.Array, tuple, list)) else [f] - - -@reg.register_compute("clip", level=15) -def compute_clip(attrs, inputs, _): - """ Clip operator. - """ - x = inputs[0] - a_min = attrs.get_float("a_min") - a_max = attrs.get_float("a_max") - const_min = tvm.const(a_min, x.dtype) - const_max = tvm.const(a_max, x.dtype) - with tvm.tag_scope(topi.tag.ELEMWISE): - x = tvm.compute( - x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") - x = tvm.compute( - x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") - return x - -# override to force partition at copy -reg.register_pattern("copy", OpPattern.INJECTIVE, level=15) - -def is_packed_layout(layout): - """Check if layout is packed layout""" - if layout == "NCHW": - return False - if "n" in layout and "c" in layout: - return True - return False - -@reg.register_alter_op_layout("conv2d", level=15) -def alter_conv2d_layout(attrs, inputs, out): - layout = attrs['layout'] - if is_packed_layout(layout): - return None - return _nn.alter_conv2d_layout(attrs, inputs, out) - - -@reg.register_compute("conv2d", level=15) -def compute_conv2d(attrs, inputs, out): - """ 2D convolution algorithm. - """ - padding = attrs.get_int_tuple("padding") - strides = attrs.get_int_tuple("strides") - dilation = attrs.get_int_tuple("dilation") - groups = attrs.get_int("groups") - layout = attrs["layout"] - out_dtype = attrs['out_dtype'] - - assert dilation == (1, 1), "not support dilate now" - if is_packed_layout(layout): - if groups == 1: - assert groups == 1 - env = get_env() - assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now" - assert env.LOG_OUT_WIDTH == 3, "only support 8bit inp for now" - inputs = list(inputs) - w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH) - assert inputs[1].dtype == "int8" - - # Apply bit packing if necessary - if w_pack_factor != 1: - kshape = list(topi.util.get_const_tuple(inputs[1].shape)) - kshape[-1] *= w_pack_factor - inputs[1] = reinterpret(inputs[1], kshape, dtype=env.wgt_dtype) - - return packed_conv2d(inputs[0], inputs[1], - padding, strides, out_dtype=out_dtype) - else: - return packed_group_conv2d(inputs[0], inputs[1], - padding, strides, groups, out_dtype=out_dtype) - return _nn.compute_conv2d(attrs, inputs, out) - - -@reg.register_schedule("conv2d", level=15) -def schedule_conv2d(attrs, outs, target): - """ 2D convolution schedule. - """ - layout = attrs["layout"] - groups = attrs.get_int('groups') - - if is_packed_layout(layout): - target = tvm.target.create(target) - if target.device_name == "vta": - if groups == 1: - return schedule_packed_conv2d(outs) - else: - return schedule_packed_group_conv2d(outs) - elif str(target).startswith("llvm"): - return tvm.create_schedule([x.op for x in outs]) - else: - raise RuntimeError("not support target %s" % target) - return _nn.schedule_conv2d(attrs, outs, target) - - -@reg.register_compute("conv2d_transpose", level=15) -def compute_conv2d_transpose(attrs, inputs, out): - """ 2D convolution algorithm. - """ - padding = attrs.get_int_tuple("padding") - strides = attrs.get_int_tuple("strides") - dilation = attrs.get_int_tuple("dilation") - layout = attrs["layout"] - out_dtype = attrs['out_dtype'] - - assert dilation == (1, 1), "not support dilate now" - if is_packed_layout(layout): - return packed_conv2d_transpose(inputs[0], inputs[1], - padding, strides, - out_dtype=out_dtype) - return _nn.compute_conv2d_transpose(attrs, inputs, out) + cfg.add_flop(2 * np.prod(topi.util.get_const_tuple(oshape)) * + kshape[2] * kshape[3] * ishape[1] * ishape[-1]) + return res -@reg.register_schedule("conv2d_transpose", level=15) -def schedule_conv2d_transpose(attrs, outs, target): - """ 2D convolution schedule. - """ - layout = attrs["layout"] - if is_packed_layout(layout): - target = tvm.target.create(target) - if target.device_name == "vta": - return schedule_packed_conv2d_transpose(outs) - elif str(target).startswith("llvm"): - return tvm.create_schedule([x.op for x in outs]) - else: - raise RuntimeError("not support target %s" % target) - return _nn.schedule_conv2d_transpose(attrs, outs, target) - -def _get_workload(data, pad_data, kernel, output): - """ Get the workload structure. - """ - o_shape = topi.util.get_const_tuple(output.shape) - d_shape = topi.util.get_const_tuple(data.shape) - k_shape = topi.util.get_const_tuple(kernel.shape) - o_b, o_c, o_h, o_w, ob_blk, o_blk = o_shape - i_b, i_c, i_h, i_w, ib_blk, i_blk = d_shape - k_o, k_i, k_h, k_w, ko_blk, ki_blk = k_shape - # For now we need to assume that input channel blocking is the same - # as the output channel blocking - assert o_blk == i_blk - assert ob_blk == ib_blk - # Make sure that dimensions match - assert o_b == i_b - assert o_blk == ko_blk - assert i_blk == ki_blk - assert k_o == o_c - assert k_i == i_c - # Scale the channel size - i_b *= ib_blk - i_c *= i_blk - o_c *= o_blk - if pad_data is not None: - p_shape = topi.util.get_const_tuple(pad_data.shape) - h_pad = (p_shape[2] - d_shape[2]) // 2 - w_pad = (p_shape[3] - d_shape[3]) // 2 - else: - h_pad, w_pad = 0, 0 - h_str = (i_h + h_pad*2 - k_h) // (o_h - 1) - w_str = (i_w + w_pad*2 - k_w) // (o_w - 1) - return Workload(i_b, i_h, i_w, i_c, o_c, k_h, k_w, h_pad, w_pad, h_str, w_str) - -def schedule_packed_conv2d(outs, planStr=None, skip_load_inp=False, skip_load_wgt=False, - skip_load_acc=False, skip_store_out=False, skip_alu=False, - skip_gemm=False): - """ Schedule the packed conv2d. - """ +@autotvm.register_topi_schedule(topi.generic.schedule_conv2d_nchw, 'vta', 'direct') +def schedule_packed_conv2d(cfg, outs, + skip_load_inp=False, skip_load_wgt=False, skip_load_acc=False, + skip_store_out=False, skip_alu=False, skip_gemm=False): assert len(outs) == 1 output = outs[0] ewise_inputs = [] @@ -410,6 +73,19 @@ def _traverse(op): _traverse(output.op) assert len(conv2d_res) == 1 conv2d_stage = conv2d_res[0].output(0) + s = tvm.create_schedule(output.op) + + ##### space definition begin ##### + b, co, h, w, bi, ci = s[conv2d_stage].op.axis + ci, kh, kw, bci = s[conv2d_stage].op.reduce_axis + cfg.define_split('tile_b', b, num_outputs=2) + cfg.define_split('tile_h', h, num_outputs=2) + cfg.define_split('tile_w', w, num_outputs=2) + cfg.define_split('tile_ci', ci, num_outputs=2) + cfg.define_split('tile_co', co, num_outputs=2) + cfg.define_knob('oc_nthread', [1, 2]) + cfg.define_knob('h_nthread', [1, 2]) + ###### space definition end ###### data, kernel = conv2d_stage.op.input_tensors if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: @@ -418,34 +94,9 @@ def _traverse(op): data = temp else: pad_data = None - wrkld = _get_workload(data, pad_data, kernel, output) - - if wrkld in _SCHEDULE_STR_MAP and planStr is None: - planStr = _SCHEDULE_STR_MAP[wrkld] - logging.info("Apply pre-cached schedule for %s->%s", str(wrkld) , planStr) - if planStr: - matchObj = re.match( r'b(\d+)_oc(\d+)_ic(\d+)_h(\d+)_w(\d+)_oct(\d+)_ht(\d+)', planStr) - b_factor = int(matchObj.group(1)) - oc_factor = int(matchObj.group(2)) - ic_factor = int(matchObj.group(3)) - h_factor = int(matchObj.group(4)) - w_factor = int(matchObj.group(5)) - oc_nthread = int(matchObj.group(6)) - h_nthread = int(matchObj.group(7)) - plan = Schedule(b_factor=b_factor, - oc_factor=oc_factor, - ic_factor=ic_factor, - h_factor=h_factor, - w_factor=w_factor, - oc_nthread=oc_nthread, - h_nthread=h_nthread) - else: - plan = find_schedules(wrkld, vt_only=True, best_only=True)[0] - logging.info("Trying to find plan for %s", wrkld) env = get_env() mock = env.mock - load_inp = mock.dma_copy if skip_load_inp else env.dma_copy load_wgt = mock.dma_copy if skip_load_wgt else env.dma_copy load_acc = mock.dma_copy if skip_load_acc else env.dma_copy @@ -453,9 +104,8 @@ def _traverse(op): alu = mock.alu if skip_alu else env.alu gemm = mock.gemm if skip_gemm else env.gemm - # schedule1 + # schedule oshape = topi.util.get_const_tuple(output.shape) - s = tvm.create_schedule(output.op) # setup pad if pad_data is not None: @@ -465,28 +115,23 @@ def _traverse(op): cdata = s.cache_read(data, env.inp_scope, [conv2d_stage]) ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage]) s[conv2d_stage].set_scope(env.acc_scope) + # cache read input cache_read_ewise = [] - for consumer, tensor in ewise_inputs: cache_read_ewise.append( s.cache_read(tensor, env.acc_scope, [consumer])) + # set ewise scope for op in ewise_ops: s[op].set_scope(env.acc_scope) s[op].pragma(s[op].op.axis[0], alu) - # tile - oc_factor = (plan.oc_factor if plan.oc_factor - else plan.out_filter // env.BLOCK_OUT) - h_factor = (plan.h_factor if plan.h_factor else oshape[2]) - w_factor = (plan.w_factor if plan.w_factor else oshape[3]) - x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis - x_co0, x_co1 = s[output].split(x_co, factor=oc_factor) - x_i0, x_i1 = s[output].split(x_i, factor=h_factor) - x_j0, x_j1 = s[output].split(x_j, factor=w_factor) + x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co) + x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i) + x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j) s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) store_pt = x_j0 @@ -500,14 +145,14 @@ def _traverse(op): s[tensor].pragma(s[tensor].op.axis[0], load_acc) # virtual threading along output channel axes - if plan.oc_nthread > 1: - _, v_t = s[output].split(x_co0, factor=plan.oc_nthread) + if cfg['oc_nthread'].val > 1: + _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val) s[output].reorder(v_t, x_bo) s[output].bind(v_t, tvm.thread_axis("cthread")) # virtual threading along spatial rows - if plan.h_nthread > 1: - _, v_t = s[output].split(x_i0, factor=plan.h_nthread) + if cfg['h_nthread'].val > 1: + _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val) s[output].reorder(v_t, x_bo) s[output].bind(v_t, tvm.thread_axis("cthread")) @@ -515,10 +160,9 @@ def _traverse(op): k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i) - if plan.ic_factor: - k_o, _ = s[conv2d_stage].split(k_o, factor=plan.ic_factor) - s[cdata].compute_at(s[conv2d_stage], k_o) - s[ckernel].compute_at(s[conv2d_stage], k_o) + k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o) + s[cdata].compute_at(s[conv2d_stage], k_o) + s[ckernel].compute_at(s[conv2d_stage], k_o) # Use VTA instructions s[cdata].pragma(s[cdata].op.axis[0], load_inp) @@ -527,31 +171,3 @@ def _traverse(op): s[output].pragma(x_co1, store_out) return s -class Conv2DSchedule(object): - """ 2D convolution schedule object. - """ - def __init__(self, - b_factor=1, - oc_factor=1, - ic_factor=1, - h_factor=1, - w_factor=0, - oc_nthread=0, - h_nthread=0, - debug_sync=False): - self.b_factor = b_factor - self.oc_factor = oc_factor - self.ic_factor = ic_factor - self.h_factor = h_factor - self.w_factor = w_factor - self.oc_nthread = oc_nthread - self.h_nthread = h_nthread - self.debug_sync = debug_sync - - def __str__(self): - return "{}.{}.{}.{}.{}.{}.{}".format( - self.b_factor, self.oc_factor, self.ic_factor, - self.h_factor, self.w_factor, - self.oc_nthread, self.h_nthread) - -Schedule = Conv2DSchedule diff --git a/vta/scripts/tune_conv.py b/vta/scripts/tune_conv.py index 95f13cc54f5a3..6b1740092b160 100644 --- a/vta/scripts/tune_conv.py +++ b/vta/scripts/tune_conv.py @@ -1,168 +1,25 @@ -"""Tuning a conv2d operator """ -import tvm -import sys +"""Tuning a single conv2d operator""" import logging + +import tvm from tvm import autotvm from tvm.contrib.util import get_lower_ir import topi import vta import vta.testing -from vta.top.testing import my_clip env = vta.get_env() -def vta_build_func(measure_input, tmp_dir, **kwargs): - import time - import os - from tvm.autotvm.measure.measure_methods import BuildResult - from random import getrandbits - from tvm.autotvm.util import get_const_tuple - tic = time.time() - try: - filename = os.path.join(tmp_dir, "tmp_func_%0x.tar" % getrandbits(64)) - target, task, config = measure_input - - with target: - s, args = task.instantiate(config) - if not config.valid(): - raise InstantiationError(config.errors) - - func = vta.build(s, args, target='ext_dev', target_host=task.target_host) - - arg_info = tuple((get_const_tuple(x.shape), x.dtype) for x in args) - func.export_library(filename) - except Exception as e: # pylint: disable=broad-except - return BuildResult(None, None, e, time.time() - tic) - return BuildResult(filename, arg_info, None, time.time() - tic) - - -def schedule_packed_conv2d(cfg, outs, - skip_load_inp=False, skip_load_wgt=False, skip_load_acc=False, - skip_store_out=False, skip_alu=False, skip_gemm=False): - """Schedule the packed conv2d. - """ - assert len(outs) == 1 - output = outs[0] - ewise_inputs = [] - ewise_ops = [] - conv2d_res = [] - assert output.op.input_tensors[0].dtype == "int32" - - def _traverse(op): - if topi.tag.is_broadcast(op.tag): - if not op.same_as(output.op): - ewise_ops.append(op) - for tensor in op.input_tensors: - if isinstance(tensor.op, tvm.tensor.PlaceholderOp): - ewise_inputs.append((op, tensor)) - else: - _traverse(tensor.op) - else: - assert op.tag == "packed_conv2d" - conv2d_res.append(op) - - _traverse(output.op) - assert len(conv2d_res) == 1 - conv2d_stage = conv2d_res[0].output(0) - s = tvm.create_schedule(output.op) - - ##### space definition begin ##### - b, co, h, w, bi, ci = s[conv2d_stage].op.axis - ci, kh, kw, bci = s[conv2d_stage].op.reduce_axis - cfg.define_split('tile_b', b, num_outputs=2) - cfg.define_split('tile_h', h, num_outputs=2) - cfg.define_split('tile_w', w, num_outputs=2) - cfg.define_split('tile_ci', ci, num_outputs=2) - cfg.define_split('tile_co', co, num_outputs=2) - cfg.define_knob('oc_nthread', [1, 2]) - cfg.define_knob('h_nthread', [1, 2]) - ###### space definition end ###### - - data, kernel = conv2d_stage.op.input_tensors - if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: - temp = data.op.input_tensors[0] - pad_data = data - data = temp - else: - pad_data = None - - mock = env.mock - load_inp = mock.dma_copy if skip_load_inp else env.dma_copy - load_wgt = mock.dma_copy if skip_load_wgt else env.dma_copy - load_acc = mock.dma_copy if skip_load_acc else env.dma_copy - store_out = mock.dma_copy if skip_store_out else env.dma_copy - alu = mock.alu if skip_alu else env.alu - gemm = mock.gemm if skip_gemm else env.gemm - - # schedule - oshape = topi.util.get_const_tuple(output.shape) - - # setup pad - if pad_data is not None: - cdata = pad_data - s[pad_data].set_scope(env.inp_scope) - else: - cdata = s.cache_read(data, env.inp_scope, [conv2d_stage]) - ckernel = s.cache_read(kernel, env.wgt_scope, [conv2d_stage]) - s[conv2d_stage].set_scope(env.acc_scope) - - # cache read input - cache_read_ewise = [] - for consumer, tensor in ewise_inputs: - cache_read_ewise.append( - s.cache_read(tensor, env.acc_scope, [consumer])) - - # set ewise scope - for op in ewise_ops: - s[op].set_scope(env.acc_scope) - s[op].pragma(s[op].op.axis[0], alu) - - # tile - x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis - x_co0, x_co1 = cfg['tile_co'].apply(s, output, x_co) - x_i0, x_i1 = cfg['tile_h'].apply(s, output, x_i) - x_j0, x_j1 = cfg['tile_w'].apply(s, output, x_j) - s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) - store_pt = x_j0 - - # set all compute scopes - s[conv2d_stage].compute_at(s[output], store_pt) - for op in ewise_ops: - s[op].compute_at(s[output], store_pt) - - for tensor in cache_read_ewise: - s[tensor].compute_at(s[output], store_pt) - s[tensor].pragma(s[tensor].op.axis[0], load_acc) - - # virtual threading along output channel axes - if cfg['oc_nthread'].val > 1: - _, v_t = s[output].split(x_co0, factor=cfg['oc_nthread'].val) - s[output].reorder(v_t, x_bo) - s[output].bind(v_t, tvm.thread_axis("cthread")) - - # virtual threading along spatial rows - if cfg['h_nthread'].val > 1: - _, v_t = s[output].split(x_i0, factor=cfg['h_nthread'].val) - s[output].reorder(v_t, x_bo) - s[output].bind(v_t, tvm.thread_axis("cthread")) - - x_bo, x_co, x_i, x_j, x_bi, x_ci = s[conv2d_stage].op.axis - k_o, d_i, d_j, k_i = s[conv2d_stage].op.reduce_axis - s[conv2d_stage].reorder(x_bo, k_o, x_j, d_j, d_i, x_co, x_i, x_bi, x_ci, k_i) - - k_o, _ = cfg['tile_ci'].apply(s, conv2d_stage, k_o) - s[cdata].compute_at(s[conv2d_stage], k_o) - s[ckernel].compute_at(s[conv2d_stage], k_o) - - # Use VTA instructions - s[cdata].pragma(s[cdata].op.axis[0], load_inp) - s[ckernel].pragma(s[ckernel].op.axis[0], load_wgt) - s[conv2d_stage].tensorize(x_bi, gemm) - s[output].pragma(x_co1, store_out) - return s - -@autotvm.template +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype): data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN) kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN) @@ -172,41 +29,42 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype): OW = (W + 2 * padding[1] - KW) // strides[1] + 1 data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) - bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype) - - w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH) - kernel_shape_pack = kernel_shape[:-1] + (kernel_shape[-1] // w_pack_factor,) - kernel_arg = tvm.placeholder(kernel_shape_pack, dtype="int8", name="kernel_arg") - kernel = vta.reinterpret(kernel_arg, kernel_shape, dtype=env.wgt_dtype) - - res_conv = vta.top.packed_conv2d(data, kernel, padding=padding, strides=strides) - res = topi.right_shift(res_conv, 8) - res = topi.add(res, bias) - res = my_clip(res, 0, 127) - res = topi.cast(res, "int8") + bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + + with tvm.target.vta(): + res = topi.nn.conv2d(data, kernel, padding=padding, strides=strides, + layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN), out_dtype='int32') + res = topi.add(res, bias) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.current_target().device_name == 'vta': + s = topi.generic.schedule_conv2d_nchw([res]) + else: + s = tvm.create_schedule([res.op]) - cfg = autotvm.get_config() - s = schedule_packed_conv2d(cfg, [res]) - cfg.add_flop(2 * N * CI * OH * OW * CO * KH * KW) - return s, [data, kernel_arg, bias, res] + return s, [data, kernel, bias, res] if __name__ == '__main__': + model = env.TARGET N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype = \ 1, 64, 56, 56, 64, 3, 3, (1, 1), (1, 1), 'int8', 'int32' task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype), - target='ext_dev', target_host=env.target_host) + target=tvm.target.vta(model), target_host=env.target_host, template_key='direct') print(task.config_space) # logging config (for printing tuning log to the screen) + logging.basicConfig() logging.getLogger('autotvm').setLevel(logging.DEBUG) - logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout)) measure_option = autotvm.measure_option( - builder=autotvm.LocalBuilder(build_func=vta_build_func), - runner=autotvm.RPCRunner( - 'ultra96', 'fleet', 9190)) + builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func), + runner=autotvm.RPCRunner(model, 'fleet', 9190, number=4, repeat=3, timeout=30, + check_correctness=True)) tuner = autotvm.tuner.RandomTuner(task) tuner.tune(n_trial=len(task.config_space), diff --git a/vta/scripts/tune_resnet.py b/vta/scripts/tune_resnet.py new file mode 100644 index 0000000000000..f3a64323723e7 --- /dev/null +++ b/vta/scripts/tune_resnet.py @@ -0,0 +1,224 @@ +import os +import argparse +import time + +import numpy as np + +import tvm +from tvm import rpc, autotvm +from tvm.autotvm.measure.measure_methods import request_remote +from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner +from tvm.contrib import graph_runtime, util +from tvm.contrib.download import download + +import topi +import nnvm.compiler +import vta +import vta.testing + +env = vta.get_env() + +def register_vta_tuning_tasks(): + from tvm.autotvm.task.nnvm_integration import TaskExtractEnv, deserialize_args + + @tvm.tag_scope(tag=topi.tag.ELEMWISE) + def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + + # init old env + TaskExtractEnv([]) + + @autotvm.task.register("topi_nn_conv2d", override=True) + def _topi_nn_conv2d(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + A, W = args[:2] + + with tvm.target.vta(): + res = topi.nn.conv2d(*args, **kwargs) + res = topi.right_shift(res, 8) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.current_target().device_name == 'vta': + s = topi.generic.schedule_conv2d_nchw([res]) + else: + s = tvm.create_schedule([res.op]) + return s, [A, W, res] + + +def generate_graph(sym, params, target, target_host): + # Populate the shape and data type dictionary + shape_dict = {"data": (1, 3, 224, 224)} + dtype_dict = {"data": 'float32'} + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Apply NNVM graph optimization passes + sym = vta.graph.clean_cast(sym) + sym = vta.graph.clean_conv_fuse(sym) + assert env.BLOCK_IN == env.BLOCK_OUT + sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT) + + # Compile NNVM graph + with nnvm.compiler.build_config(opt_level=3): + with vta.build_config(): + graph, lib, params = nnvm.compiler.build( + sym, target, shape_dict, dtype_dict, + params=params, target_host=target_host) + + return graph, lib, params + + +def extract_tasks(sym, params, target, target_host): + # Populate the shape and data type dictionary + shape_dict = {"data": (1, 3, 224, 224)} + dtype_dict = {"data": 'float32'} + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict.update({k: str(v.dtype) for k, v in params.items()}) + + # Apply NNVM graph optimization passes + sym = vta.graph.clean_cast(sym) + sym = vta.graph.clean_conv_fuse(sym) + assert env.BLOCK_IN == env.BLOCK_OUT + sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT) + + with vta.build_config(): + tasks = autotvm.task.extract_from_graph(sym, target=target, target_host=target_host, + shape=shape_dict, dtype=dtype_dict, symbols=(nnvm.sym.conv2d,)) + return tasks + + +def download_model(): + url = "https://github.com/uwsaml/web-data/raw/master/vta/models/" + categ_fn = 'synset.txt' + graph_fn = 'resnet18_qt8.json' + params_fn = 'resnet18_qt8.params' + data_dir = '_data' + if not os.path.exists(data_dir): + os.makedirs(data_dir) + + for file in [categ_fn, graph_fn, params_fn]: + if not os.path.isfile(file): + download(os.path.join(url, file), os.path.join(data_dir, file)) + + sym = nnvm.graph.load_json(open(os.path.join(data_dir, graph_fn)).read()) + params = nnvm.compiler.load_param_dict(open(os.path.join(data_dir, params_fn), 'rb').read()) + + return sym, params + + +def tune_tasks(tasks, + measure_option, + tuner='xgb', + n_trial=1000, + early_stopping=None, + log_filename='tuning.log', + use_transfer_learning=True, + try_winograd=True): + # create tmp log file + tmp_log_file = log_filename + ".tmp" + if os.path.exists(tmp_log_file): + os.remove(tmp_log_file) + + for i, tsk in enumerate(reversed(tasks)): + prefix = "[Task %2d/%2d] " % (i+1, len(tasks)) + + # create tuner + if tuner == 'xgb' or tuner == 'xgb-rank': + tuner_obj = XGBTuner(tsk, loss_type='rank') + elif tuner == 'ga': + tuner_obj = GATuner(tsk, pop_size=50) + elif tuner == 'random': + tuner_obj = RandomTuner(tsk) + elif tuner == 'gridsearch': + tuner_obj = GridSearchTuner(tsk) + else: + raise ValueError("Invalid tuner: " + tuner) + + if use_transfer_learning: + if os.path.isfile(tmp_log_file): + tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) + + # do tuning + n_trial_ = min(n_trial, len(tsk.config_space)) + tuner_obj.tune(n_trial_, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(n_trial_, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file)]) + + # pick best records to a cache file + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) + +if __name__ == '__main__': + device_key = env.TARGET + + tuning_opt = { + 'log_filename': 'resnet-18.log', + + 'tuner': 'random', + 'n_trial': 1e9, + 'early_stopping': None, + + 'measure_option': autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func), + runner=autotvm.RPCRunner(device_key, 'fleet', 9190, + number=4, repeat=3, timeout=60, + check_correctness=True)) + } + + # download model + sym, params = download_model() + + # extract tasks + register_vta_tuning_tasks() + + print("Extract tasks...") + target = tvm.target.vta(device_key) + target_host = env.target_host + tasks = extract_tasks(sym, params, target, target_host) + + print("Tuning...") + tune_tasks(tasks, **tuning_opt) + + # compile kernels with history best records + with autotvm.tophub.context(target, extra_files=[tuning_opt['log_filename']]): + print("Compile...") + graph, lib, params = generate_graph(sym, params, target, target_host) + input_shape = (1, 3, 224, 224) + dtype = 'float32' + + # export library + tmp = util.tempdir() + filename = "net.tar" + lib.export_library(tmp.relpath(filename)) + + # upload module to device + print("Upload...") + remote = autotvm.measure.request_remote(device_key, 'fleet', 9190, timeout=10000) + remote.upload(tmp.relpath(filename)) + rlib = remote.load_module(filename) + + # upload parameters to device + ctx = remote.context(str(target), 0) + rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()} + data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype)) + module = graph_runtime.create(graph, rlib, ctx) + module.set_input('data', data_tvm) + module.set_input(**rparams) + + # evaluate + print("Evaluate inference time cost...") + ftimer = module.module.time_evaluator("run", ctx, number=3, repeat=3) + prof_res = np.array(ftimer().results) * 1000 # convert to millisecond + print("Mean inference time (std dev): %.2f ms (%.2f ms)" % + (np.mean(prof_res), np.std(prof_res))) + diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index e159cdfe77a5f..1c377fea61188 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -901,12 +901,12 @@ class CommandQueue { insn_queue_.InitSpace(); device_ = VTADeviceAlloc(); CHECK(device_ != nullptr); - printf("Initialize VTACommandHandle...\n"); + //printf("Initialize VTACommandHandle...\n"); } ~CommandQueue() { VTADeviceFree(device_); - printf("Close VTACommandhandle...\n"); + //printf("Close VTACommandhandle...\n"); } uint32_t GetElemBytes(uint32_t memory_id) { diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_conv2d.py index fb3b350f129b9..21dfe26974038 100644 --- a/vta/tests/python/integration/test_benchmark_topi_conv2d.py +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d.py @@ -1,5 +1,11 @@ """Testing if we can generate code in topi style""" +import os +import json +from collections import namedtuple + +import numpy as np + import tvm from tvm import autotvm from tvm.contrib import util @@ -8,155 +14,333 @@ import topi.testing import vta import vta.testing -from vta.top.testing import run_vta_conv2d, my_clip -import numpy as np -import os -import json -Workload = vta.top.vta_conv2d.Workload +Workload = namedtuple("Conv2DWorkload", + ['batch', 'height', 'width', 'in_filter', 'out_filter', + 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) +# ResNet18 workloads +resnet_wkls = [ + # Workloads of resnet18 on imagenet + # Workload(env.BATCH, 224, 224, env.BLOCK_IN, 64, 7, 7, 3, 3, 2, 2), + ('resnet-18.C2', Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1)), + #('resnet-18.C3', Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1)), # this is a workload from a wrong model + ('resnet-18.C4', Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)), + ('resnet-18.C5', Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2)), + ('resnet-18.C6', Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1)), + ('resnet-18.C7', Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2)), + ('resnet-18.C8', Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2)), + ('resnet-18.C9', Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1)), + ('resnet-18.C10', Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2)), + ('resnet-18.C11', Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2)), + ('resnet-18.C12', Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1)), +] -def test_cpu_conv2d(): - def run_cpu_conv2d(env, remote, key, batch_size, wl): - data_shape = (batch_size, wl.in_filter, wl.height, wl.width) - kernel_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) - fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 - fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 - data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) - kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) +def run_cpu_conv2d(env, remote, wl, target): + data_shape = (wl.batch, wl.in_filter, wl.height, wl.width) + kernel_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) + fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 + fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype) + + with target: res_conv = topi.nn.conv2d( - data, kernel, padding=(wl.hpad, wl.wpad), - strides=(wl.hstride, wl.wstride), - out_dtype="int32") + data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), "NCHW", "int32") res = topi.right_shift(res_conv, 8) res = my_clip(res, 0, 127) res = topi.cast(res, "int8") + s = topi.generic.schedule_conv2d_nchw([res]) - # To compute number of ops, use a x2 factor for FMA - num_ops = 2 * batch_size * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter + # To compute number of ops, use a x2 factor for FMA + num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter - a_shape = (batch_size, wl.in_filter, wl.height, wl.width) - w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) - stride = (wl.hstride, wl.wstride) - data_dtype = data.dtype - kernel_dtype = kernel.dtype - acc_dtype = env.acc_dtype + # get reference data + a_shape = (wl.batch, wl.in_filter, wl.height, wl.width) + w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) + stride = (wl.hstride, wl.wstride) + data_dtype = data.dtype + kernel_dtype = kernel.dtype + acc_dtype = env.acc_dtype + padding = wl.hpad + + @memoize("vta.tests.test_benchmark_topi.conv2d.cpu.verify_nhwc") + def get_ref_data(): + a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) + w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype) + a_np = np.abs(a_np) + w_np = np.abs(w_np) + b_np = topi.testing.conv2d_nchw_python( + a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype) + return a_np, w_np, b_np + + check_correctness = True + print_ir = False + + # build + mod = tvm.build(s, [data, kernel, res], target=target, target_host=env.target_host, name="conv2d") + temp = util.tempdir() + mod.save(temp.relpath("conv2d.o")) + remote.upload(temp.relpath("conv2d.o")) + f = remote.load_module("conv2d.o") + + # run + ctx = remote.context(str(target)) + data_orig, kernel_orig, res_ref = get_ref_data() + res_shape = topi.util.get_const_tuple(res.shape) + res_np = np.zeros(res_shape).astype(res.dtype) + + data_arr = tvm.nd.array(data_orig, ctx) + kernel_arr = tvm.nd.array(kernel_orig, ctx) + res_arr = tvm.nd.array(res_np, ctx) + time_f = f.time_evaluator("conv2d", ctx, number=4) + cost = time_f(data_arr, kernel_arr, res_arr) + res_unpack = res_arr.asnumpy() + + # verify + if check_correctness: assert wl.hpad == wl.wpad + stride = (wl.hstride, wl.wstride) padding = wl.hpad + res_ref = res_ref >> 8 + res_ref = np.clip(res_ref, 0, 127) + res_ref = res_ref.astype("int8") + np.testing.assert_allclose(res_unpack, res_ref) - @memoize("vta.tests.test_benchmark_topi.conv2d.cpu.verify_nhwc") - def get_ref_data(): - a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype) - w_np = (np.random.uniform(size=w_shape) * 4).astype(kernel_dtype) - a_np = np.abs(a_np) - w_np = np.abs(w_np) - b_np = topi.testing.conv2d_nchw_python( - a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype) - return a_np, w_np, b_np - - - def verify(s, check_correctness): - mod = tvm.build(s, [data, kernel, res], - target_host=env.target_host, - name="conv2d") - temp = util.tempdir() - mod.save(temp.relpath("conv2d.o")) - remote.upload(temp.relpath("conv2d.o")) - f = remote.load_module("conv2d.o") - # verify - ctx = remote.cpu(0) - # Data in original format - data_orig, kernel_orig, res_ref = get_ref_data() - res_shape = topi.util.get_const_tuple(res.shape) - res_np = np.zeros(res_shape).astype(res.dtype) - data_arr = tvm.nd.array(data_orig, ctx) - kernel_arr = tvm.nd.array(kernel_orig, ctx) - res_arr = tvm.nd.array(res_np, ctx) - time_f = f.time_evaluator("conv2d", ctx, number=5) - cost = time_f(data_arr, kernel_arr, res_arr) - res_unpack = res_arr.asnumpy() - if check_correctness: - assert wl.hpad == wl.wpad - stride = (wl.hstride, wl.wstride) - padding = wl.hpad - res_ref = res_ref >> 8 - res_ref = np.clip(res_ref, 0, 127) - res_ref = res_ref.astype("int8") - np.testing.assert_allclose(res_unpack, res_ref) - return cost - - def conv_normal(print_ir): - print("----- CONV2D CPU End-to-End Test -------") - s = topi.generic.schedule_conv2d_nchw([res]) - if print_ir: - print(tvm.lower(s, [data, kernel, res], simple_mode=True)) - cost = verify(s, True) - gops = (num_ops / cost.mean) / float(10 ** 9) - print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) - - conv_normal(False) + if print_ir: + print(tvm.lower(s, [data, kernel, res], simple_mode=True)) + gops = (num_ops / cost.mean) / float(10 ** 9) + print("CPU TEST: Time cost = %g sec/op, %g GOPS\n" % (cost.mean, gops)) + +def test_cpu_conv2d(): def _run(env, remote): - # ResNet18 workloads - resnet = { - # Workloads of resnet18 on imagenet - 0: Workload(1, 224, 224, 16, 64, 7, 7, 3, 3, 2, 2), - 1: Workload(1, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - 2: Workload(1, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - 3: Workload(1, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - 4: Workload(1, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - 5: Workload(1, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - 6: Workload(1, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - 7: Workload(1, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - 8: Workload(1, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - 9: Workload(1, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - 10: Workload(1, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - 11: Workload(1, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), - } - batch_size = 1 - for i in range(1, len(resnet)): - wl = resnet[i] - key = "resnet-cfg[%d]" % i - print("key=%s" % key) - print(wl) - with tvm.target.create(env.target_vta_cpu): - run_cpu_conv2d(env, remote, key, batch_size, wl) - - # load pre-tuned operator parameters for ARM CPU - with autotvm.tophub.context('llvm -device=arm_cpu'): - vta.testing.run(_run) + target = env.target_vta_cpu + + with autotvm.tophub.context(target): # load pre-tuned operator parameters for ARM CPU + for name, wl in resnet_wkls: + print(name, wl) + run_cpu_conv2d(env, remote, wl, target) + + vta.testing.run(_run) + + +def _sign_extend(value, bits): + sign_bit = 1 << (bits - 1) + return (value & (sign_bit - 1)) - (value & sign_bit) +_vector_sign_extend = np.vectorize(_sign_extend) + + +def _pack(x, width): + assert(len(x.shape)==6) + assert(x.dtype=="int8") + pack_factor = 8 // width + mask = ((1 << width) - 1) + + s = x.shape + s_reshape = s[:-1] + (s[-1] // pack_factor, pack_factor) + s_pack = s[:-1] + (s[-1] // pack_factor,) + x_reshape = x.reshape(s_reshape) + x_packed = np.zeros(s_pack, dtype="int8") + for i in range(0, pack_factor): + x_packed |= (x_reshape[:,:,:,:,:,:,i] & mask) << (i * width) + + return x_packed + +def _unpack(x, width): + assert(len(x.shape)==6) + assert(x.dtype=="int8") + pack_factor = 8 // width + mask = ((1 << width) - 1) + + s = x.shape + x_unpack = np.zeros(s[:] + (pack_factor,), dtype=x.dtype) + for i in range(0, pack_factor): + x_unpack[:,:,:,:,:,:,i] = _vector_sign_extend(((x >> (i * width)) & mask), width) + + return x_unpack.reshape(s[:-1] + (s[-1] * pack_factor,)) + +@tvm.tag_scope(tag=topi.tag.ELEMWISE) +def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.const(a_min, x.dtype) + const_max = tvm.const(a_max, x.dtype) + x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA") + x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB") + return x + + +def run_vta_conv2d(env, remote, wl, target, check_correctness=True, print_ir=False, + plan_str=None, samples=4, profileOnly=False, + skip_load_inp=False, skip_load_wgt=False, skip_load_acc=False, + skip_store_out=False, skip_alu=False, skip_gemm=False): + + data_shape = (wl.batch//env.BATCH, wl.in_filter//env.BLOCK_IN, + wl.height, wl.width, env.BATCH, env.BLOCK_IN) + kernel_shape = (wl.out_filter//env.BLOCK_OUT, wl.in_filter//env.BLOCK_IN, + wl.hkernel, wl.wkernel, env.BLOCK_OUT, env.BLOCK_IN) + bias_shape = (wl.batch//env.BATCH, wl.out_filter//env.BLOCK_OUT, + 1, 1, env.BATCH, env.BLOCK_OUT) + data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype) + bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype) + + # Handle quantized inputs (less than 8 bits) + # x_pack_factor = 1 << (3 - env.LOG_INP_WIDTH) + # data_shape_pack = data_shape[:-1] + (data_shape[-1]//x_pack_factor,) + # data_arg = tvm.placeholder( + # data_shape_pack, + # dtype="int8", name="data_arg") + # data = vta.reinterpret(data_arg, data_shape, dtype=env.inp_dtype) + + # Handle quantized kernels (less than 8 bits) + w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH) + kernel_shape_pack = kernel_shape[:-1] + (kernel_shape[-1]//w_pack_factor,) + kernel_arg = tvm.placeholder(kernel_shape_pack, dtype="int8", name="kernel_arg") + kernel = vta.reinterpret(kernel_arg, kernel_shape, dtype=env.wgt_dtype) + + with target: + res_conv = topi.nn.conv2d( + data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), + "NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN), 'int32') + res = topi.right_shift(res_conv, 8) + res = topi.add(res, bias) + res = my_clip(res, 0, (1 << env.OUT_WIDTH-1)-1) + res = topi.cast(res, "int8") + + s = topi.generic.schedule_conv2d_nchw([res]) + #planStr=plan_str, + #skip_load_inp=skip_load_inp, + #skip_load_wgt=skip_load_wgt, + #skip_load_acc=skip_load_acc, + #skip_store_out=skip_store_out, + #skip_alu=skip_alu, + #skip_gemm=skip_gemm) + if print_ir: + print(vta.lower(s, [data, kernel_arg, bias, res], simple_mode=True)) + + # Handle quantized outputs (less than 8 bits) + # o_pack_factor = 1 << (3 - env.LOG_OUT_WIDTH) + res_shape = topi.util.get_const_tuple(res.shape) + # res_shape_pack = res_shape[:-1] + (res_shape[-1]//o_pack_factor,) + # res_arg = vta.reinterpret(res, res_shape_pack, dtype="int8") + + # generate referene data + fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 + fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 + num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter + + a_shape = (wl.batch, wl.in_filter, wl.height, wl.width) + w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel) + stride = (wl.hstride, wl.wstride) + data_dtype, kernel_dtype, acc_dtype = data.dtype, kernel.dtype, env.acc_dtype + padding = wl.hpad + INP_WIDTH, WGT_WIDTH = env.INP_WIDTH, env.WGT_WIDTH + + @memoize("vta.tests.test_benchmark_topi.conv2d.verify_nhwc") + def get_ref_data(): + # derive min max for input and weight types (max non inclusive) + a_min, a_max = 0 - (1 << (INP_WIDTH - 1)), (1 << (INP_WIDTH - 1)) + w_min, w_max = 0 - (1 << (WGT_WIDTH - 1)), (1 << (WGT_WIDTH - 1)) + a_np = np.random.randint(a_min, a_max, size=a_shape).astype("int8") + w_np = np.random.randint(w_min, w_max, size=w_shape).astype("int8") + b_np = topi.testing.conv2d_nchw_python( + a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype) + return a_np, w_np, b_np + + mod = vta.build(s, + [data, kernel_arg, bias, res], + target, + env.target_host, name="conv2d") + + # Data in original format + data_orig, kernel_orig, res_ref = get_ref_data() + bias_orig = (np.random.uniform(size=(wl.batch, wl.out_filter,)) * (1 << (env.INP_WIDTH + env.WGT_WIDTH - 2))) + bias_orig = bias_orig.astype("int32") + bias_orig = np.abs(bias_orig) + + data_packed = data_orig.reshape( + wl.batch//env.BATCH, env.BATCH, + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, + wl.height, wl.width).transpose((0, 2, 4, 5, 1, 3)) + kernel_packed = kernel_orig.reshape( + wl.out_filter//env.BLOCK_OUT, env.BLOCK_OUT, + wl.in_filter//env.BLOCK_IN, env.BLOCK_IN, + wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3)) + bias_packed = bias_orig.reshape( + wl.batch // env.BATCH, wl.out_filter // env.BLOCK_OUT, + 1, 1, env.BATCH, env.BLOCK_OUT) + + # Quantized packing + data_qpacked = _pack(data_packed, env.INP_WIDTH) + kernel_qpacked = _pack(kernel_packed, env.WGT_WIDTH) + + # Upload + temp = util.tempdir() + mod.save(temp.relpath("conv2d.o")) + remote.upload(temp.relpath("conv2d.o")) + f = remote.load_module("conv2d.o") + + ctx = remote.context(str(target)) + res_np = np.zeros(res_shape).astype(res.dtype) + data_arr = tvm.nd.array(data_qpacked, ctx) + kernel_arr = tvm.nd.array(kernel_qpacked, ctx) + bias_arr = tvm.nd.array(bias_packed, ctx) + res_arr = tvm.nd.array(res_np, ctx) + time_f = f.time_evaluator("conv2d", ctx, number=samples) + + # In sim mode, collect simulator runtime statistics + stats = {} + cost = None + if env.TARGET == "sim": + # Check if we're in local RPC mode (allows us to rebuild the + # runtime on the fly when varying the VTA designs) + local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0")) + if local_rpc: + remote.get_function("vta.simulator.profiler_clear")() + if profileOnly: + remote.get_function("vta.simulator.profiler_debug_mode")(1) + cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) + stats = json.loads(remote.get_function("vta.simulator.profiler_status")()) + else: + simulator.clear_stats() + if profileOnly: + simulator.debug_mode(1) + cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) + stats = simulator.stats() + else: + cost = time_f(data_arr, kernel_arr, bias_arr, res_arr) + + # Check correctness + correct = False + if check_correctness: + res_unpack = res_arr.asnumpy() + res_unpack = _unpack(res_unpack.astype("int8"), env.OUT_WIDTH) + res_unpack = res_unpack.transpose( + (0, 4, 1, 5, 2, 3)).reshape(wl.batch, wl.out_filter, fout_height, fout_width) + assert wl.hpad == wl.wpad + stride = (wl.hstride, wl.wstride) + padding = wl.hpad + res_ref = res_ref >> 8 + res_ref += bias_orig.reshape(wl.out_filter, 1, 1) + res_ref = np.clip(res_ref, 0, (1 << env.OUT_WIDTH-1)-1) + res_ref = res_ref.astype("int8") + correct = np.allclose(res_unpack, res_ref) + + gops = (num_ops / cost.mean) / float(10 ** 9) + print("VTA TEST: Time cost = %g sec/op, %g GOPS\n" % (cost.mean, gops)) + + return correct, cost, stats + def test_vta_conv2d(): def _run(env, remote): - # ResNet18 workloads - resnet = { - # Workloads of resnet18 on imagenet - 0: Workload(env.BATCH, 224, 224, env.BLOCK_IN, 64, 7, 7, 3, 3, 2, 2), - 1: Workload(env.BATCH, 56, 56, 64, 64, 3, 3, 1, 1, 1, 1), - 2: Workload(env.BATCH, 56, 56, 64, 64, 1, 1, 0, 0, 1, 1), - 3: Workload(env.BATCH, 56, 56, 64, 128, 3, 3, 1, 1, 2, 2), - 4: Workload(env.BATCH, 56, 56, 64, 128, 1, 1, 0, 0, 2, 2), - 5: Workload(env.BATCH, 28, 28, 128, 128, 3, 3, 1, 1, 1, 1), - 6: Workload(env.BATCH, 28, 28, 128, 256, 3, 3, 1, 1, 2, 2), - 7: Workload(env.BATCH, 28, 28, 128, 256, 1, 1, 0, 0, 2, 2), - 8: Workload(env.BATCH, 14, 14, 256, 256, 3, 3, 1, 1, 1, 1), - 9: Workload(env.BATCH, 14, 14, 256, 512, 3, 3, 1, 1, 2, 2), - 10: Workload(env.BATCH, 14, 14, 256, 512, 1, 1, 0, 0, 2, 2), - 11: Workload(env.BATCH, 7, 7, 512, 512, 3, 3, 1, 1, 1, 1), - } - - for i in range(0, len(resnet)): - wl = resnet[i] - fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1 - fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1 - num_ops = 2 * wl.batch * fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter - print("----- CONV2D End-to-End Test -------") - print(wl) - (correct, cost, stats) = run_vta_conv2d(env, remote, wl) - assert(correct) - gops = (num_ops / cost.mean) / float(10 ** 9) - print("\tTime cost = %g sec/op, %g GOPS" % (cost.mean, gops)) + target = env.target + + with autotvm.tophub.context(target): # load pre-tuned operator parameters for ARM CPU + for name, wl in resnet_wkls: + print(name, wl) + run_vta_conv2d(env, remote, wl, target) vta.testing.run(_run) @@ -164,3 +348,4 @@ def _run(env, remote): if __name__ == "__main__": # test_cpu_conv2d() test_vta_conv2d() + diff --git a/vta/tutorials/resnet.py b/vta/tutorials/resnet.py index 2900b383ac8c3..faa24e137c8d5 100644 --- a/vta/tutorials/resnet.py +++ b/vta/tutorials/resnet.py @@ -59,7 +59,7 @@ def process_image(image): # Takes in the graph runtime, and an image, and returns top result and time def classify(m, image): m.set_input('data', image) - timer = m.module.time_evaluator("run", ctx, number=1) + timer = m.module.time_evaluator("run", ctx, number=3) tcost = timer() tvm_output = m.get_output(0) top = np.argmax(tvm_output.asnumpy()[0]) @@ -108,7 +108,6 @@ def generate_graph(graph_fn, params_fn, target): params=params, target_host=target_host) # Save the compiled inference graph library - assert tvm.module.enabled("rpc") temp = util.tempdir() lib.save(temp.relpath("graphlib.o")) @@ -193,15 +192,11 @@ def generate_graph(graph_fn, params_fn, target): # ------------------------ # Build the ResNet graph runtime, and configure the parameters. -# llvm_command = 'llvm -device=arm_cpu -model=pynq {}'.format(env.llvm_triple) # run arm cpu on pynq -# llvm_command = 'llvm -device=arm_cpu -model=ultra96 {}'.format(env.llvm_triple) # run arm cpu on ultra96 -# llvm_command = 'llvm -device=vta -model=pynq {}'.format(env.llvm_triple) # run vta cpu on pynq -llvm_command = 'llvm -device=vta -model=ultra96 {}'.format(env.llvm_triple) # run vta cpu on ultra96 - -target = tvm.target.create(llvm_command) +#target = tvm.target.arm_cpu(env.TARGET) # run on arm cpu +target = tvm.target.vta(env.TARGET) # Device context -ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0) +ctx = remote.context(str(target)) # Build the graph runtime graph, lib, params = generate_graph(os.path.join(data_dir, graph_fn), @@ -231,7 +226,7 @@ def generate_graph(graph_fn, params_fn, target): m.set_input('data', image) # Perform inference -timer = m.module.time_evaluator("run", ctx, number=1) +timer = m.module.time_evaluator("run", ctx, number=4) tcost = timer() # Get classification results