From 5481665381ff31fb6f8323e7d3f17801435c976d Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 13 Apr 2018 22:05:09 -0700 Subject: [PATCH] [RUNTIME] Simplify dynamic library and code path. (#27) * [RUNTIME] Simplify dynamic library and code path. * reword the readme --- vta/Makefile | 16 +-- vta/README.md | 24 +++- vta/make/config.mk | 7 +- vta/python/vta/__init__.py | 3 +- vta/python/vta/environment.py | 65 ++++++++-- vta/python/vta/exec/rpc_server.py | 53 +++++---- vta/python/vta/testing/simulator.py | 1 - vta/python/vta/testing/util.py | 14 ++- vta/src/data_buffer.cc | 44 ------- vta/src/data_buffer.h | 90 -------------- .../{tvm/vta_device_api.cc => device_api.cc} | 29 ++--- vta/src/runtime.cc | 112 +++++++++++++++++- vta/tests/python/unittest/test_vta_insn.py | 15 +++ 13 files changed, 260 insertions(+), 213 deletions(-) delete mode 100644 vta/src/data_buffer.cc delete mode 100644 vta/src/data_buffer.h rename vta/src/{tvm/vta_device_api.cc => device_api.cc} (75%) diff --git a/vta/Makefile b/vta/Makefile index 6bfa82dc2e10..e112b39c9f58 100644 --- a/vta/Makefile +++ b/vta/Makefile @@ -53,12 +53,9 @@ else NO_WHOLE_ARCH= --no-whole-archive endif - -all: lib/libvta.so lib/libvta_runtime.so - VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc) -ifeq ($(TARGET), VTA_PYNQ_TARGET) +ifeq ($(VTA_TARGET), pynq) VTA_LIB_SRC += $(wildcard src/pynq/*.cc) LDFLAGS += -L/usr/lib -lsds_lib LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/ @@ -66,24 +63,23 @@ ifeq ($(TARGET), VTA_PYNQ_TARGET) LDFLAGS += -l:libdma.so endif -ifeq ($(TARGET), sim) +ifeq ($(VTA_TARGET), sim) VTA_LIB_SRC += $(wildcard src/sim/*.cc) endif VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC)) +all: lib/libvta.so + build/%.o: src/%.cc @mkdir -p $(@D) - $(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/$*.d + $(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d $(CXX) -c $(CFLAGS) -c $< -o $@ -lib/libvta.so: $(filter-out build/runtime.o, $(VTA_LIB_OBJ)) +lib/libvta.so: $(VTA_LIB_OBJ) @mkdir -p $(@D) $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS) -lib/libvta_runtime.so: build/runtime.o - @mkdir -p $(@D) - $(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS) lint: pylint cpplint diff --git a/vta/README.md b/vta/README.md index 39d094cc54e6..5408f3b950b1 100644 --- a/vta/README.md +++ b/vta/README.md @@ -1,11 +1,25 @@ -Open Hardware/Software Stack for Vertical Deep Learning System Optimization -============================================== +VTA: Open, Modular, Deep Learning Accelerator Stack +=================================================== [![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE) -VTA is an open hardware/software co-design stack for deep learning systems systems. -It provides a customizable hardware accelerator template for deep learning inference workloads, -combined with a fully functional compiler stack built with TVM. +VTA(versatile tensor accelerator) is an open-source deep learning accelerator stack. +It is not just an open-source hardware, but is an end to end solution that includes +the entire software stack on top of VTA open-source hardware. + + +The key features include: + +- Generic, modular open-source hardware + - Streamlined workflow to deploy to FPGAs. + - Simulator support +- Driver and JIT runtime for both simulated backend and FPGA. +- End to end TVM stack integration + - Direct optimization and deploy models from deep learning frameworks via TVM stack. + - Customized and extendible TVM compiler backend + - Flexible RPC support to ease the deployment, you can program it with python :) + +VTA is part of our effort on [TVM Stack](http://www.tvmlang.org/). License ------- diff --git a/vta/make/config.mk b/vta/make/config.mk index e329dcf987b8..2bf25132e245 100644 --- a/vta/make/config.mk +++ b/vta/make/config.mk @@ -26,8 +26,8 @@ ADD_LDFLAGS= # the additional compile flags you want to add ADD_CFLAGS= -# the hardware target -TARGET = pynq +# the hardware target, can be [sim, pynq] +VTA_TARGET = pynq #--------------------- # VTA hardware parameters @@ -88,7 +88,8 @@ $(shell echo "$$(( $(VTA_LOG_ACC_BUFF_SIZE) + $(VTA_LOG_OUT_WIDTH) - $(VTA_LOG_A VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" ) # Update ADD_CFLAGS -ADD_CFLAGS += \ +ADD_CFLAGS += + -DVTA_TARGET=$(VTA_TARGET)\ -DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \ -DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \ -DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \ diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py index 693a4124f40b..80091f80d164 100644 --- a/vta/python/vta/__init__.py +++ b/vta/python/vta/__init__.py @@ -1,5 +1,6 @@ """TVM-based VTA Compiler Toolchain""" from __future__ import absolute_import as _abs +import sys from .environment import get_env, Environment @@ -10,5 +11,5 @@ from .rpc_client import reconfig_runtime, program_fpga from . import graph -except ImportError: +except (ImportError, RuntimeError): pass diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index 8ff2bbce2787..a59e66a564b2 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -3,14 +3,10 @@ from __future__ import absolute_import as _abs import os +import glob import copy - -try: - # Allow missing import in config mode. - import tvm - from . import intrin -except ImportError: - pass +import tvm +from . import intrin class DevContext(object): @@ -65,6 +61,45 @@ def get_task_qid(self, qid): return 1 if self.DEBUG_NO_SYNC else qid +class PkgConfig(object): + """Simple package config tool for VTA. + + This is used to provide runtime specific configurations. + """ + def __init__(self, env): + curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + proj_root = os.path.abspath(os.path.join(curr_path, "../../")) + # include path + self.include_path = [ + "-I%s/include" % proj_root, + "-I%s/nnvm/tvm/include" % proj_root, + "-I%s/nnvm/tvm/dlpack/include" % proj_root, + "-I%s/nnvm/dmlc-core/include" % proj_root + ] + # List of source files that can be used to build standalone library. + self.lib_source = [] + self.lib_source += glob.glob("%s/src/*.cc" % proj_root) + self.lib_source += glob.glob("%s/src/%s/*.cc" % (proj_root, env.TARGET)) + # macro keys + self.macro_defs = [] + for key in env.cfg_keys: + self.macro_defs.append("-DVTA_%s=%s" % (key, str(getattr(env, key)))) + + if env.TARGET == "pynq": + self.ldflags = [ + "-L/usr/lib", + "-lsds_lib", + "-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/", + "-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/", + "-l:libdma.so"] + else: + self.ldflags = [] + + @property + def cflags(self): + return self.include_path + self.macro_defs + + class Environment(object): """Hareware configuration object. @@ -160,6 +195,7 @@ def __init__(self, cfg): self.mock_mode = False self._mock_env = None self._dev_ctx = None + self._pkg_config = None @property def dev(self): @@ -168,6 +204,13 @@ def dev(self): self._dev_ctx = DevContext(self) return self._dev_ctx + @property + def pkg_config(self): + """PkgConfig instance""" + if self._pkg_config is None: + self._pkg_config = PkgConfig(self) + return self._pkg_config + @property def mock(self): """A mock version of the Environment @@ -249,7 +292,7 @@ def mem_info_wgt_buffer(): head_address=None) @tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope) -def mem_info_out_buffer(): +def mem_info_acc_buffer(): spec = get_env() return tvm.make.node("MemoryInfo", unit_bits=spec.ACC_ELEM_BITS, @@ -265,6 +308,7 @@ def coproc_sync(op): "int32", "VTASynchronize", get_env().dev.command_handle, 1<<31) + @tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push") def coproc_dep_push(op): return tvm.call_extern( @@ -272,6 +316,7 @@ def coproc_dep_push(op): get_env().dev.command_handle, op.args[0], op.args[1]) + @tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop") def coproc_dep_pop(op): return tvm.call_extern( @@ -288,7 +333,6 @@ def _init_env(): for k in Environment.cfg_keys: keys.add("VTA_" + k) - keys.add("TARGET") if not os.path.isfile(filename): raise RuntimeError( @@ -303,8 +347,9 @@ def _init_env(): val = line.split("=")[1].strip() if k.startswith("VTA_"): k = k[4:] + try: cfg[k] = int(val) - else: + except ValueError: cfg[k] = val return Environment(cfg) diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py index ebaf15f8dc37..ca99416f7818 100644 --- a/vta/python/vta/exec/rpc_server.py +++ b/vta/python/vta/exec/rpc_server.py @@ -9,26 +9,46 @@ import os import ctypes import tvm +from tvm._ffi.base import c_str from tvm.contrib import rpc, cc +from ..environment import get_env + @tvm.register_func("tvm.contrib.rpc.server.start", override=True) def server_start(): - """callback when server starts.""" + """VTA RPC server extension.""" # pylint: disable=unused-variable curr_path = os.path.dirname( os.path.abspath(os.path.expanduser(__file__))) dll_path = os.path.abspath( - os.path.join(curr_path, "../../../lib/libvta_runtime.so")) + os.path.join(curr_path, "../../../lib/libvta.so")) runtime_dll = [] _load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module") - @tvm.register_func("tvm.contrib.rpc.server.load_module", override=True) - def load_module(file_name): + def load_vta_dll(): + """Try to load vta dll""" if not runtime_dll: runtime_dll.append(ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL)) + logging.info("Loading VTA library: %s", dll_path) + return runtime_dll[0] + + @tvm.register_func("tvm.contrib.rpc.server.load_module", override=True) + def load_module(file_name): + load_vta_dll() return _load_module(file_name) + @tvm.register_func("device_api.ext_dev") + def ext_dev_callback(): + load_vta_dll() + return tvm.get_global_func("device_api.ext_dev")() + + @tvm.register_func("tvm.contrib.vta.init", override=True) + def program_fpga(file_name): + path = tvm.get_global_func("tvm.contrib.rpc.server.workpath")(file_name) + load_vta_dll().VTAProgram(c_str(path)) + logging.info("Program FPGA with %s", file_name) + @tvm.register_func("tvm.contrib.rpc.server.shutdown", override=True) def server_shutdown(): if runtime_dll: @@ -47,17 +67,15 @@ def reconfig_runtime(cflags): if runtime_dll: raise RuntimeError("Can only reconfig in the beginning of session...") cflags = cflags.split() + env = get_env() cflags += ["-O2", "-std=c++11"] + cflags += env.pkg_config.include_path + ldflags = env.pkg_config.ldflags lib_name = dll_path - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - proj_root = os.path.abspath(os.path.join(curr_path, "../../../")) - runtime_source = os.path.join(proj_root, "src/runtime.cc") - cflags += ["-I%s/include" % proj_root] - cflags += ["-I%s/nnvm/tvm/include" % proj_root] - cflags += ["-I%s/nnvm/tvm/dlpack/include" % proj_root] - cflags += ["-I%s/nnvm/dmlc-core/include" % proj_root] - logging.info("Rebuild runtime dll with %s", str(cflags)) - cc.create_shared(lib_name, [runtime_source], cflags) + source = env.pkg_config.lib_source + logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s", + dll_path, str(cflags), str(source), str(ldflags)) + cc.create_shared(lib_name, source, cflags + ldflags) def main(): @@ -75,14 +93,6 @@ def main(): help="Report to RPC tracker") args = parser.parse_args() logging.basicConfig(level=logging.INFO) - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - proj_root = os.path.abspath(os.path.join(curr_path, "../../../")) - lib_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so")) - - libs = [] - for file_name in [lib_path]: - libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) - logging.info("Load additional library %s", file_name) if args.tracker: url, port = args.tracker.split(":") @@ -99,7 +109,6 @@ def main(): args.port_end, key=args.key, tracker_addr=tracker_addr) - server.libs += libs server.proc.join() if __name__ == "__main__": diff --git a/vta/python/vta/testing/simulator.py b/vta/python/vta/testing/simulator.py index bb436a1853a8..3505e49eeb04 100644 --- a/vta/python/vta/testing/simulator.py +++ b/vta/python/vta/testing/simulator.py @@ -10,7 +10,6 @@ def _load_lib(): curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) dll_path = [ os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")), - os.path.abspath(os.path.join(curr_path, "../../../lib/libvta_runtime.so")) ] runtime_dll = [] if not all(os.path.exists(f) for f in dll_path): diff --git a/vta/python/vta/testing/util.py b/vta/python/vta/testing/util.py index bbf6417a167e..402546c0efa1 100644 --- a/vta/python/vta/testing/util.py +++ b/vta/python/vta/testing/util.py @@ -15,10 +15,18 @@ def run(run_func): run_func : function(env, remote) """ env = get_env() - # run on simulator - if simulator.enabled(): + + # Run on local sim rpc if necessary + local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0")) + if local_rpc: env.TARGET = "sim" - run_func(env, rpc.LocalSession()) + remote = rpc.connect("localhost", local_rpc) + run_func(env, remote) + else: + # run on simulator + if simulator.enabled(): + env.TARGET = "sim" + run_func(env, rpc.LocalSession()) # Run on PYNQ if env variable exists pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None) diff --git a/vta/src/data_buffer.cc b/vta/src/data_buffer.cc deleted file mode 100644 index 99f959ad8c8b..000000000000 --- a/vta/src/data_buffer.cc +++ /dev/null @@ -1,44 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file data_buffer.cc - * \brief Buffer related API for VTA. - * \note Buffer API remains stable across VTA designes. - */ -#include "./data_buffer.h" - -void* VTABufferAlloc(size_t size) { - return vta::DataBuffer::Alloc(size); -} - -void VTABufferFree(void* buffer) { - vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); -} - -void VTABufferCopy(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - int kind_mask) { - vta::DataBuffer* from_buffer = nullptr; - vta::DataBuffer* to_buffer = nullptr; - - if (kind_mask & 2) { - from_buffer = vta::DataBuffer::FromHandle(from); - from = from_buffer->virt_addr(); - } - if (kind_mask & 1) { - to_buffer = vta::DataBuffer::FromHandle(to); - to = to_buffer->virt_addr(); - } - if (from_buffer) { - from_buffer->InvalidateCache(from_offset, size); - } - - memcpy(static_cast(to) + to_offset, - static_cast(from) + from_offset, - size); - if (to_buffer) { - to_buffer->FlushCache(to_offset, size); - } -} diff --git a/vta/src/data_buffer.h b/vta/src/data_buffer.h deleted file mode 100644 index fba46dc07efa..000000000000 --- a/vta/src/data_buffer.h +++ /dev/null @@ -1,90 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file data_buffer.h - * \brief VTA runtime internal data buffer structure. - */ -#ifndef VTA_DATA_BUFFER_H_ -#define VTA_DATA_BUFFER_H_ - -#include -#include -#include -#include - -namespace vta { - -/*! \brief Enable coherent access between VTA and CPU. */ -static const bool kBufferCoherent = true; - -/*! - * \brief Data buffer represents data on CMA. - */ -struct DataBuffer { - /*! \return Virtual address of the data. */ - void* virt_addr() const { - return data_; - } - /*! \return Physical address of the data. */ - uint32_t phy_addr() const { - return phy_addr_; - } - /*! - * \brief Invalidate the cache of given location in data buffer. - * \param offset The offset to the data. - * \param size The size of the data. - */ - void InvalidateCache(size_t offset, size_t size) { - if (!kBufferCoherent) { - VTAInvalidateCache(phy_addr_ + offset, size); - } - } - /*! - * \brief Invalidate the cache of certain location in data buffer. - * \param offset The offset to the data. - * \param size The size of the data. - */ - void FlushCache(size_t offset, size_t size) { - if (!kBufferCoherent) { - VTAFlushCache(phy_addr_ + offset, size); - } - } - /*! - * \brief Allocate a buffer of a given size. - * \param size The size of the buffer. - */ - static DataBuffer* Alloc(size_t size) { - void* data = VTAMemAlloc(size, 1); - assert(data != nullptr); - DataBuffer* buffer = new DataBuffer(); - buffer->data_ = data; - buffer->phy_addr_ = VTAMemGetPhyAddr(data); - return buffer; - } - /*! - * \brief Free the data buffer. - * \param buffer The buffer to be freed. - */ - static void Free(DataBuffer* buffer) { - VTAMemFree(buffer->data_); - delete buffer; - } - /*! - * \brief Create data buffer header from buffer ptr. - * \param buffer The buffer pointer. - * \return The corresponding data buffer header. - */ - static DataBuffer* FromHandle(const void* buffer) { - return const_cast( - reinterpret_cast(buffer)); - } - - private: - /*! \brief The internal data. */ - void* data_; - /*! \brief The physical address of the buffer, excluding header. */ - uint32_t phy_addr_; -}; - -} // namespace vta - -#endif // VTA_DATA_BUFFER_H_ diff --git a/vta/src/tvm/vta_device_api.cc b/vta/src/device_api.cc similarity index 75% rename from vta/src/tvm/vta_device_api.cc rename to vta/src/device_api.cc index e4671d8a0207..687ec625a02c 100644 --- a/vta/src/tvm/vta_device_api.cc +++ b/vta/src/device_api.cc @@ -1,15 +1,14 @@ /*! * Copyright (c) 2018 by Contributors - * \file vta_device_api.cc - * \brief VTA device API for TVM + * \file device_api.cc + * \brief TVM device API for VTA */ #include #include #include -#include -#include "../../nnvm/tvm/src/runtime/workspace_pool.h" +#include "../nnvm/tvm/src/runtime/workspace_pool.h" namespace tvm { @@ -26,7 +25,8 @@ class VTADeviceAPI final : public DeviceAPI { } void* AllocDataSpace(TVMContext ctx, - size_t size, size_t alignment, + size_t size, + size_t alignment, TVMType type_hint) final { return VTABufferAlloc(size); } @@ -84,22 +84,9 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { dmlc::ThreadLocalStore::Get()->FreeWorkspace(ctx, data); } -std::string VTARPCGetPath(const std::string& name) { - static const PackedFunc* f = - runtime::Registry::Get("tvm.contrib.rpc.server.workpath"); - CHECK(f != nullptr) << "require tvm.contrib.rpc.server.workpath"; - return (*f)(name); -} - -// Global functions that can be called -TVM_REGISTER_GLOBAL("tvm.contrib.vta.init") -.set_body([](TVMArgs args, TVMRetValue* rv) { - std::string path = VTARPCGetPath(args[0]); - VTAProgram(path.c_str()); - LOG(INFO) << "VTA initialization end with bistream " << path; - }); - -TVM_REGISTER_GLOBAL("device_api.ext_dev") +// Register device api with override. +static TVM_ATTRIBUTE_UNUSED auto& __register_dev__ = +::tvm::runtime::Registry::Register("device_api.ext_dev", true) .set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = VTADeviceAPI::Global().get(); *rv = static_cast(ptr); diff --git a/vta/src/runtime.cc b/vta/src/runtime.cc index da5109c141ee..c0de87fa339d 100644 --- a/vta/src/runtime.cc +++ b/vta/src/runtime.cc @@ -5,8 +5,6 @@ * * The runtime depends on specific instruction * stream spec as specified in hw_spec.h - * It is intended to be used as a dynamic library - * to enable hot swapping of hardware configurations. */ #include #include @@ -14,15 +12,87 @@ #include #include +#include #include #include #include #include -#include "./data_buffer.h" namespace vta { +/*! \brief Enable coherent access between VTA and CPU. */ +static const bool kBufferCoherent = true; + +/*! + * \brief Data buffer represents data on CMA. + */ +struct DataBuffer { + /*! \return Virtual address of the data. */ + void* virt_addr() const { + return data_; + } + /*! \return Physical address of the data. */ + uint32_t phy_addr() const { + return phy_addr_; + } + /*! + * \brief Invalidate the cache of given location in data buffer. + * \param offset The offset to the data. + * \param size The size of the data. + */ + void InvalidateCache(size_t offset, size_t size) { + if (!kBufferCoherent) { + VTAInvalidateCache(phy_addr_ + offset, size); + } + } + /*! + * \brief Invalidate the cache of certain location in data buffer. + * \param offset The offset to the data. + * \param size The size of the data. + */ + void FlushCache(size_t offset, size_t size) { + if (!kBufferCoherent) { + VTAFlushCache(phy_addr_ + offset, size); + } + } + /*! + * \brief Allocate a buffer of a given size. + * \param size The size of the buffer. + */ + static DataBuffer* Alloc(size_t size) { + void* data = VTAMemAlloc(size, 1); + CHECK(data != nullptr); + DataBuffer* buffer = new DataBuffer(); + buffer->data_ = data; + buffer->phy_addr_ = VTAMemGetPhyAddr(data); + return buffer; + } + /*! + * \brief Free the data buffer. + * \param buffer The buffer to be freed. + */ + static void Free(DataBuffer* buffer) { + VTAMemFree(buffer->data_); + delete buffer; + } + /*! + * \brief Create data buffer header from buffer ptr. + * \param buffer The buffer pointer. + * \return The corresponding data buffer header. + */ + static DataBuffer* FromHandle(const void* buffer) { + return const_cast( + reinterpret_cast(buffer)); + } + + private: + /*! \brief The internal data. */ + void* data_; + /*! \brief The physical address of the buffer, excluding header. */ + uint32_t phy_addr_; +}; + /*! * \brief Micro op kernel. * Contains functions to construct the kernel with prefix Push. @@ -1130,6 +1200,42 @@ class CommandQueue { } // namespace vta +void* VTABufferAlloc(size_t size) { + return vta::DataBuffer::Alloc(size); +} + +void VTABufferFree(void* buffer) { + vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); +} + +void VTABufferCopy(const void* from, + size_t from_offset, + void* to, + size_t to_offset, + size_t size, + int kind_mask) { + vta::DataBuffer* from_buffer = nullptr; + vta::DataBuffer* to_buffer = nullptr; + + if (kind_mask & 2) { + from_buffer = vta::DataBuffer::FromHandle(from); + from = from_buffer->virt_addr(); + } + if (kind_mask & 1) { + to_buffer = vta::DataBuffer::FromHandle(to); + to = to_buffer->virt_addr(); + } + if (from_buffer) { + from_buffer->InvalidateCache(from_offset, size); + } + + memcpy(static_cast(to) + to_offset, + static_cast(from) + from_offset, + size); + if (to_buffer) { + to_buffer->FlushCache(to_offset, size); + } +} VTACommandHandle VTATLSCommandHandle() { return vta::CommandQueue::ThreadLocal().get(); diff --git a/vta/tests/python/unittest/test_vta_insn.py b/vta/tests/python/unittest/test_vta_insn.py index 339d8d31e238..79a20be13de9 100644 --- a/vta/tests/python/unittest/test_vta_insn.py +++ b/vta/tests/python/unittest/test_vta_insn.py @@ -468,7 +468,22 @@ def _run(env, remote): vta.testing.run(_run) + +def test_runtime_array(): + def _run(env, remote): + n = 100 + ctx = remote.ext_dev(0) + x_np = np.random.randint( + 1, 10, size=(n, n, env.BATCH, env.BLOCK_OUT)).astype("int8") + x_nd = tvm.nd.array(x_np, ctx) + np.testing.assert_equal(x_np, x_nd.asnumpy()) + + vta.testing.run(_run) + + if __name__ == "__main__": + print("Array test") + test_runtime_array() print("Load/store test") test_save_load_out() print("Padded load test")