Skip to content

Commit

Permalink
Merge pull request apache#2 from weberlo/tutorial-2
Browse files Browse the repository at this point in the history
Add components for tutorial 3
  • Loading branch information
areusch committed Jul 30, 2020
2 parents 2470ca1 + 8b91ac3 commit f9642e8
Show file tree
Hide file tree
Showing 23 changed files with 1,184 additions and 191 deletions.
18 changes: 18 additions & 0 deletions include/tvm/runtime/crt/platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ extern "C" {
*/
void __attribute__((noreturn)) TVMPlatformAbort(int code);

/*! \brief Start a device timer.
*
* The device timer used must not be running.
*
* \return An error code.
*/
int TVMPlatformTimerStart();

/*! \brief Stop the running device timer and get the elapsed time (in microseconds).
*
* The device timer used must be running.
*
* \param res_us Pointer to write elapsed time into.
*
* \return An error code.
*/
int TVMPlatformTimerStop(double* res_us);

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
14 changes: 10 additions & 4 deletions python/tvm/contrib/debugger/debug_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_DUMP_PATH_PREFIX = "_tvmdbg_"


def create(graph_json_str, libmod, ctx, dump_root=None):
def create(graph_json_str, libmod, ctx, dump_root=None, number=10, repeat=1, min_repeat_ms=1):
"""Create a runtime executor module given a graph and module.
Parameters
Expand Down Expand Up @@ -72,7 +72,8 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
"config.cmake and rebuild TVM to enable debug mode"
)
func_obj = fcreate(graph_json_str, libmod, *device_type_id)
return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root)
return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root,
number=number, repeat=repeat, min_repeat_ms=min_repeat_ms)


class GraphModuleDebug(graph_runtime.GraphModule):
Expand All @@ -99,13 +100,17 @@ class GraphModuleDebug(graph_runtime.GraphModule):
None will make a temp folder in /tmp/tvmdbg<rand_string> and does the dumping
"""

def __init__(self, module, ctx, graph_json_str, dump_root):
def __init__(self, module, ctx, graph_json_str, dump_root,
number, repeat, min_repeat_ms):
self._dump_root = dump_root
self._dump_path = None
self._get_output_by_layer = module["get_output_by_layer"]
self._run_individual = module["run_individual"]
graph_runtime.GraphModule.__init__(self, module)
self._create_debug_env(graph_json_str, ctx)
self.number = number
self.repeat = repeat
self.min_repeat_ms = min_repeat_ms

def _format_context(self, ctx):
return str(ctx[0]).upper().replace("(", ":").replace(")", "")
Expand Down Expand Up @@ -180,7 +185,8 @@ def _run_debug(self):
"""
self.debug_datum._time_list = [
[float(t) * 1e-6] for t in self.run_individual(10, 1, 1)
[float(t) * 1e-6] for t in
self.run_individual(self.number, self.repeat, self.min_repeat_ms)
]
for i, node in enumerate(self.debug_datum.get_graph_nodes()):
num_outputs = self.debug_datum.get_graph_node_output_num(node)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/micro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""MicroTVM module for bare-metal backends"""

from .artifact import Artifact
from .build import build_static_runtime, DefaultOptions, TVM_ROOT_DIR, Workspace
from .build import build_static_runtime, DefaultOptions, TVM_ROOT_DIR, CRT_ROOT_DIR, Workspace
from .compiler import Compiler, DefaultCompiler, Flasher
from .debugger import GdbRemoteDebugger, RpcDebugger
from .micro_library import MicroLibrary
Expand Down
47 changes: 47 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
configuring the passes and scripting them in Python.
"""
from tvm.ir import IRModule
# TODO(weberlo) remove when we port dtype collectors to C++
from tvm.relay.expr_functor import ExprVisitor
from tvm.relay.type_functor import TypeVisitor

from . import _ffi_api
from .feature import Feature
Expand Down Expand Up @@ -219,6 +222,50 @@ def all_type_vars(expr, mod=None):
return _ffi_api.all_type_vars(expr, use_mod)


class TyDtypeCollector(TypeVisitor):
"""Pass that collects data types used in the visited type."""

def __init__(self):
TypeVisitor.__init__(self)
self.dtypes = set()

def visit_tensor_type(self, tt):
self.dtypes.add(tt.dtype)


class ExprDtypeCollector(ExprVisitor):
"""Pass that collects data types used in all types in the visited expression."""

def __init__(self):
ExprVisitor.__init__(self)
self.ty_visitor = TyDtypeCollector()

def visit(self, expr):
if hasattr(expr, 'checked_type'):
self.ty_visitor.visit(expr.checked_type)
elif hasattr(expr, 'type_annotation'):
self.ty_visitor.visit(expr.type_annotation)
ExprVisitor.visit(self, expr)


def all_dtypes(expr):
"""Collect set of all data types used in `expr`.
Parameters
----------
expr : tvm.relay.Expr
The input expression
Returns
-------
ret : Set[String]
Set of data types used in the expression
"""
dtype_collector = ExprDtypeCollector()
dtype_collector.visit(expr)
return dtype_collector.ty_visitor.dtypes


def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
Expand Down
Loading

0 comments on commit f9642e8

Please sign in to comment.