Skip to content

Commit

Permalink
[RUNTIME] Simplify dynamic library and code path. (#27)
Browse files Browse the repository at this point in the history
* [RUNTIME] Simplify dynamic library and code path.

* reword the readme
  • Loading branch information
tqchen committed Jul 12, 2018
1 parent 63b3312 commit 5481665
Show file tree
Hide file tree
Showing 13 changed files with 260 additions and 213 deletions.
16 changes: 6 additions & 10 deletions vta/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -53,37 +53,33 @@ else
NO_WHOLE_ARCH= --no-whole-archive
endif


all: lib/libvta.so lib/libvta_runtime.so

VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc)

ifeq ($(TARGET), VTA_PYNQ_TARGET)
ifeq ($(VTA_TARGET), pynq)
VTA_LIB_SRC += $(wildcard src/pynq/*.cc)
LDFLAGS += -L/usr/lib -lsds_lib
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/
LDFLAGS += -l:libdma.so
endif

ifeq ($(TARGET), sim)
ifeq ($(VTA_TARGET), sim)
VTA_LIB_SRC += $(wildcard src/sim/*.cc)
endif

VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))

all: lib/libvta.so

build/%.o: src/%.cc
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/$*.d
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@

lib/libvta.so: $(filter-out build/runtime.o, $(VTA_LIB_OBJ))
lib/libvta.so: $(VTA_LIB_OBJ)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)

lib/libvta_runtime.so: build/runtime.o
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)

lint: pylint cpplint

Expand Down
24 changes: 19 additions & 5 deletions vta/README.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
Open Hardware/Software Stack for Vertical Deep Learning System Optimization
==============================================
VTA: Open, Modular, Deep Learning Accelerator Stack
===================================================

[![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)

VTA is an open hardware/software co-design stack for deep learning systems systems.
It provides a customizable hardware accelerator template for deep learning inference workloads,
combined with a fully functional compiler stack built with TVM.
VTA(versatile tensor accelerator) is an open-source deep learning accelerator stack.
It is not just an open-source hardware, but is an end to end solution that includes
the entire software stack on top of VTA open-source hardware.


The key features include:

- Generic, modular open-source hardware
- Streamlined workflow to deploy to FPGAs.
- Simulator support
- Driver and JIT runtime for both simulated backend and FPGA.
- End to end TVM stack integration
- Direct optimization and deploy models from deep learning frameworks via TVM stack.
- Customized and extendible TVM compiler backend
- Flexible RPC support to ease the deployment, you can program it with python :)

VTA is part of our effort on [TVM Stack](http://www.tvmlang.org/).

License
-------
Expand Down
7 changes: 4 additions & 3 deletions vta/make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ ADD_LDFLAGS=
# the additional compile flags you want to add
ADD_CFLAGS=

# the hardware target
TARGET = pynq
# the hardware target, can be [sim, pynq]
VTA_TARGET = pynq

#---------------------
# VTA hardware parameters
Expand Down Expand Up @@ -88,7 +88,8 @@ $(shell echo "$$(( $(VTA_LOG_ACC_BUFF_SIZE) + $(VTA_LOG_OUT_WIDTH) - $(VTA_LOG_A
VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )

# Update ADD_CFLAGS
ADD_CFLAGS += \
ADD_CFLAGS +=
-DVTA_TARGET=$(VTA_TARGET)\
-DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \
-DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \
-DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \
Expand Down
3 changes: 2 additions & 1 deletion vta/python/vta/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""TVM-based VTA Compiler Toolchain"""
from __future__ import absolute_import as _abs
import sys

from .environment import get_env, Environment

Expand All @@ -10,5 +11,5 @@
from .rpc_client import reconfig_runtime, program_fpga

from . import graph
except ImportError:
except (ImportError, RuntimeError):
pass
65 changes: 55 additions & 10 deletions vta/python/vta/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@
from __future__ import absolute_import as _abs

import os
import glob
import copy

try:
# Allow missing import in config mode.
import tvm
from . import intrin
except ImportError:
pass
import tvm
from . import intrin


class DevContext(object):
Expand Down Expand Up @@ -65,6 +61,45 @@ def get_task_qid(self, qid):
return 1 if self.DEBUG_NO_SYNC else qid


class PkgConfig(object):
"""Simple package config tool for VTA.
This is used to provide runtime specific configurations.
"""
def __init__(self, env):
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../"))
# include path
self.include_path = [
"-I%s/include" % proj_root,
"-I%s/nnvm/tvm/include" % proj_root,
"-I%s/nnvm/tvm/dlpack/include" % proj_root,
"-I%s/nnvm/dmlc-core/include" % proj_root
]
# List of source files that can be used to build standalone library.
self.lib_source = []
self.lib_source += glob.glob("%s/src/*.cc" % proj_root)
self.lib_source += glob.glob("%s/src/%s/*.cc" % (proj_root, env.TARGET))
# macro keys
self.macro_defs = []
for key in env.cfg_keys:
self.macro_defs.append("-DVTA_%s=%s" % (key, str(getattr(env, key))))

if env.TARGET == "pynq":
self.ldflags = [
"-L/usr/lib",
"-lsds_lib",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/",
"-l:libdma.so"]
else:
self.ldflags = []

@property
def cflags(self):
return self.include_path + self.macro_defs


class Environment(object):
"""Hareware configuration object.
Expand Down Expand Up @@ -160,6 +195,7 @@ def __init__(self, cfg):
self.mock_mode = False
self._mock_env = None
self._dev_ctx = None
self._pkg_config = None

@property
def dev(self):
Expand All @@ -168,6 +204,13 @@ def dev(self):
self._dev_ctx = DevContext(self)
return self._dev_ctx

@property
def pkg_config(self):
"""PkgConfig instance"""
if self._pkg_config is None:
self._pkg_config = PkgConfig(self)
return self._pkg_config

@property
def mock(self):
"""A mock version of the Environment
Expand Down Expand Up @@ -249,7 +292,7 @@ def mem_info_wgt_buffer():
head_address=None)

@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
def mem_info_out_buffer():
def mem_info_acc_buffer():
spec = get_env()
return tvm.make.node("MemoryInfo",
unit_bits=spec.ACC_ELEM_BITS,
Expand All @@ -265,13 +308,15 @@ def coproc_sync(op):
"int32", "VTASynchronize",
get_env().dev.command_handle, 1<<31)


@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
def coproc_dep_push(op):
return tvm.call_extern(
"int32", "VTADepPush",
get_env().dev.command_handle,
op.args[0], op.args[1])


@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
def coproc_dep_pop(op):
return tvm.call_extern(
Expand All @@ -288,7 +333,6 @@ def _init_env():

for k in Environment.cfg_keys:
keys.add("VTA_" + k)
keys.add("TARGET")

if not os.path.isfile(filename):
raise RuntimeError(
Expand All @@ -303,8 +347,9 @@ def _init_env():
val = line.split("=")[1].strip()
if k.startswith("VTA_"):
k = k[4:]
try:
cfg[k] = int(val)
else:
except ValueError:
cfg[k] = val
return Environment(cfg)

Expand Down
53 changes: 31 additions & 22 deletions vta/python/vta/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,46 @@
import os
import ctypes
import tvm
from tvm._ffi.base import c_str
from tvm.contrib import rpc, cc

from ..environment import get_env


@tvm.register_func("tvm.contrib.rpc.server.start", override=True)
def server_start():
"""callback when server starts."""
"""VTA RPC server extension."""
# pylint: disable=unused-variable
curr_path = os.path.dirname(
os.path.abspath(os.path.expanduser(__file__)))
dll_path = os.path.abspath(
os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
os.path.join(curr_path, "../../../lib/libvta.so"))
runtime_dll = []
_load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module")

@tvm.register_func("tvm.contrib.rpc.server.load_module", override=True)
def load_module(file_name):
def load_vta_dll():
"""Try to load vta dll"""
if not runtime_dll:
runtime_dll.append(ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL))
logging.info("Loading VTA library: %s", dll_path)
return runtime_dll[0]

@tvm.register_func("tvm.contrib.rpc.server.load_module", override=True)
def load_module(file_name):
load_vta_dll()
return _load_module(file_name)

@tvm.register_func("device_api.ext_dev")
def ext_dev_callback():
load_vta_dll()
return tvm.get_global_func("device_api.ext_dev")()

@tvm.register_func("tvm.contrib.vta.init", override=True)
def program_fpga(file_name):
path = tvm.get_global_func("tvm.contrib.rpc.server.workpath")(file_name)
load_vta_dll().VTAProgram(c_str(path))
logging.info("Program FPGA with %s", file_name)

@tvm.register_func("tvm.contrib.rpc.server.shutdown", override=True)
def server_shutdown():
if runtime_dll:
Expand All @@ -47,17 +67,15 @@ def reconfig_runtime(cflags):
if runtime_dll:
raise RuntimeError("Can only reconfig in the beginning of session...")
cflags = cflags.split()
env = get_env()
cflags += ["-O2", "-std=c++11"]
cflags += env.pkg_config.include_path
ldflags = env.pkg_config.ldflags
lib_name = dll_path
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
runtime_source = os.path.join(proj_root, "src/runtime.cc")
cflags += ["-I%s/include" % proj_root]
cflags += ["-I%s/nnvm/tvm/include" % proj_root]
cflags += ["-I%s/nnvm/tvm/dlpack/include" % proj_root]
cflags += ["-I%s/nnvm/dmlc-core/include" % proj_root]
logging.info("Rebuild runtime dll with %s", str(cflags))
cc.create_shared(lib_name, [runtime_source], cflags)
source = env.pkg_config.lib_source
logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s",
dll_path, str(cflags), str(source), str(ldflags))
cc.create_shared(lib_name, source, cflags + ldflags)


def main():
Expand All @@ -75,14 +93,6 @@ def main():
help="Report to RPC tracker")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
lib_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so"))

libs = []
for file_name in [lib_path]:
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
logging.info("Load additional library %s", file_name)

if args.tracker:
url, port = args.tracker.split(":")
Expand All @@ -99,7 +109,6 @@ def main():
args.port_end,
key=args.key,
tracker_addr=tracker_addr)
server.libs += libs
server.proc.join()

if __name__ == "__main__":
Expand Down
1 change: 0 additions & 1 deletion vta/python/vta/testing/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ def _load_lib():
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dll_path = [
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")),
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
]
runtime_dll = []
if not all(os.path.exists(f) for f in dll_path):
Expand Down
14 changes: 11 additions & 3 deletions vta/python/vta/testing/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,18 @@ def run(run_func):
run_func : function(env, remote)
"""
env = get_env()
# run on simulator
if simulator.enabled():

# Run on local sim rpc if necessary
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
env.TARGET = "sim"
run_func(env, rpc.LocalSession())
remote = rpc.connect("localhost", local_rpc)
run_func(env, remote)
else:
# run on simulator
if simulator.enabled():
env.TARGET = "sim"
run_func(env, rpc.LocalSession())

# Run on PYNQ if env variable exists
pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
Expand Down
Loading

0 comments on commit 5481665

Please sign in to comment.