Skip to content

Commit

Permalink
[VTA][Relay] Relay Compilation + AutoTVM compatible operator librarie…
Browse files Browse the repository at this point in the history
…s for VTA (apache#3135)
  • Loading branch information
tmoreau89 authored and tqchen committed Jun 28, 2019
1 parent 0816033 commit dacd7b6
Show file tree
Hide file tree
Showing 34 changed files with 3,228 additions and 1,021 deletions.
2 changes: 1 addition & 1 deletion include/vta/driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ extern "C" {

/*! \brief Physically contiguous buffer size limit */
#ifndef VTA_MAX_XFER
#define VTA_MAX_XFER (1<<22)
#define VTA_MAX_XFER (1<<25)
#endif

/*! PAGE SIZE */
Expand Down
1 change: 1 addition & 0 deletions python/vta/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""VTA specific buildin for runtime."""
from __future__ import absolute_import as _abs

Expand Down
7 changes: 7 additions & 0 deletions python/vta/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,10 @@ def gemm(self):
"""GEMM intrinsic"""
return self.dev.gemm

@property
def target(self):
return tvm.target.vta(model=self.TARGET)

@property
def target_host(self):
"""The target host"""
Expand All @@ -243,6 +247,9 @@ def target_host(self):
return "llvm"
raise ValueError("Unknown target %s" % self.TARGET)

@property
def target_vta_cpu(self):
return tvm.target.arm_cpu(model=self.TARGET)

def get_env():
"""Get the current VTA Environment.
Expand Down
3 changes: 3 additions & 0 deletions python/vta/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def ext_dev_callback():

@tvm.register_func("tvm.contrib.vta.init", override=True)
def program_fpga(file_name):
from pynq import xlnk
# Reset xilinx driver
xlnk.Xlnk().xlnk_reset()
path = tvm.get_global_func("tvm.rpc.server.workpath")(file_name)
env = get_env()
program_bitstream.bitstream_program(env.TARGET, path)
Expand Down
2 changes: 0 additions & 2 deletions python/vta/pkg_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ def __init__(self, cfg, proj_root):
if self.target == "pynq":
self.ldflags = [
"-L/usr/lib",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/",
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/",
"-l:libcma.so"]
else:
self.ldflags = []
Expand Down
13 changes: 13 additions & 0 deletions python/vta/testing/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,17 @@ def tsim_cycles():
"""
return tvm.get_global_func("tvm.vta.tsim.cycles")()

# debug flag to skip execution.
DEBUG_SKIP_EXEC = 1

def debug_mode(flag):
"""Set debug mode
Paramaters
----------
flag : int
The debug flag, 0 means clear all flags.
"""
tvm.get_global_func("vta.simulator.profiler_debug_mode")(flag)


LIBS = _load_lib()
28 changes: 19 additions & 9 deletions python/vta/testing/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import absolute_import as _abs

import os
from tvm import rpc
from tvm import rpc, autotvm
from ..environment import get_env
from . import simulator

Expand All @@ -42,7 +42,7 @@ def run(run_func):
# the port it's listening to, e.g. 9090
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
if local_rpc:
remote = rpc.connect("localhost", local_rpc)
remote = rpc.connect("127.0.0.1", local_rpc)
run_func(env, remote)
else:
# Make sure simulation library exists
Expand All @@ -54,12 +54,22 @@ def run(run_func):

elif env.TARGET == "pynq":

# Run on PYNQ if env variable exists
host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
port = int(os.environ.get("VTA_PYNQ_RPC_PORT", None))
if host and port:
remote = rpc.connect(host, port)
tracket_host = os.environ.get("TVM_TRACKER_HOST", None)
tracket_port = int(os.environ.get("TVM_TRACKER_PORT", None))
pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None)
pynq_port = int(os.environ.get("VTA_PYNQ_RPC_PORT", None))
# Run device from fleet node if env variables are defined
if tracket_host and tracket_port:
remote = autotvm.measure.request_remote(env.TARGET,
tracket_host,
tracket_port,
timeout=10000)
run_func(env, remote)
else:
raise RuntimeError(
"Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables")
# Next, run on PYNQ if env variables are defined
if pynq_host and pynq_port:
remote = rpc.connect(pynq_host, pynq_port)
run_func(env, remote)
else:
raise RuntimeError(
"Please set the VTA_PYNQ_RPC_HOST and VTA_PYNQ_RPC_PORT environment variables")
11 changes: 9 additions & 2 deletions python/vta/top/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
"""TVM TOPI connector, eventually most of these should go to TVM repo"""

from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
from . import bitpack
from .graphpack import graph_pack
from . import op
from . import vta_conv2d
from . import arm_conv2d
from . import vta_dense

# NNVM is deprecated for VTA
# from . import nnvm_bitpack
# from .nnvm_graphpack import nnvm_graph_pack
# from . import nnvm_op
37 changes: 0 additions & 37 deletions python/vta/top/arm_conv2d.py

This file was deleted.

90 changes: 90 additions & 0 deletions python/vta/top/bitpack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=ungrouped-imports

"""Bit packing operators"""
from __future__ import absolute_import as _abs

import tvm
from topi import util

from tvm.relay.op.op import register_compute, register_schedule
from tvm.relay.op.op import register_pattern, OpPattern
from tvm.relay.op.op import schedule_injective

def bitpack(data, bits, pack_type="int8", name="bitpack"):
"""Packs lowest dimension into format needed by VTA
Parameters
----------
pack_axis : int
index of the axis to pack in data
bit_axis : int
index of axis to place bit axis in resulting packed data
Returns
-------
packed : Tensor
The packed tensor.
"""
shape_vec = list(data.shape)
if pack_type == 'int8':
data_width = 8
elif pack_type == 'int16':
data_width = 16
elif pack_type == 'int32':
data_width = 32
else:
raise RuntimeError("Unknown pack type %s" % pack_type)
assert data_width % bits == 0
lanes = data_width // bits

# Data must be in multiples of the data_width
assert util.get_const_int(shape_vec[-1]) % lanes == 0, "Not a multiple of word size"
shape_vec[-1] = shape_vec[-1] // lanes
oshape = tuple(shape_vec)

def _bitpack(*indices):
ret = None
mask = tvm.const((1 << bits) - 1, pack_type)
for k in range(lanes):
idx = list(indices)
idx[-1] = idx[-1] * lanes + k
elem = data(*idx).astype(pack_type)
if k == 0:
ret = elem & mask
else:
val = (elem & mask) << tvm.const(k * bits, pack_type)
ret = ret | val
return ret

return tvm.compute(
oshape, _bitpack, name=name, tag='bitpack')


@register_compute("bitpack", level=15)
def compute_bitpack(attrs, inputs):
lanes = attrs.lanes
dtype = inputs[0].dtype
assert dtype == "int8"
width = 8
assert width % lanes == 0
bits = 8 // lanes
return bitpack(inputs[0], bits, dtype)

register_schedule("bitpack", schedule_injective)
register_pattern("bitpack", OpPattern.INJECTIVE)
Loading

0 comments on commit dacd7b6

Please sign in to comment.