forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[AUTOTVM] TOPI integration for ARM CPU (apache#1487)
- Loading branch information
1 parent
21b3c07
commit 0354c31
Showing
78 changed files
with
3,504 additions
and
2,306 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -188,3 +188,6 @@ build* | |
|
||
# Jetbrain | ||
.idea | ||
|
||
# tmp file | ||
.nfs* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Performance Benchmark | ||
|
||
## Results | ||
|
||
See results on wiki page https://github.com/dmlc/tvm/wiki/Benchmark | ||
|
||
## How to Reproduce | ||
|
||
### ARM CPU | ||
We use RPC infrastructure in TVM to make device management easy. So you need to use it for reproducing benchmark results. | ||
|
||
1. Start an RPC Tracker on the host machine | ||
```bash | ||
python3 -m tvm.exec.rpc_tracker | ||
``` | ||
|
||
2. Register devices to the tracker | ||
* For Linux device | ||
* Build tvm runtime on your device [Help](https://docs.tvm.ai/tutorials/nnvm/deploy_model_on_rasp.html#build-tvm-runtime-on-device) | ||
* Register your device to tracker by | ||
```bash | ||
python3 -m tvm.exec.rpc_sever --tracker=[HOST_IP]:9190 --key=[DEVICE_KEY] | ||
``` | ||
replace `[HOST_IP]` with the IP address of the host machine, `[DEVICE_KEY]` with the name of device. | ||
|
||
E.g. Here is an example command for RK3399, | ||
`python3 -m tvm.exec.rpc_sever --tracker=10.77.1.123:9190 --key=rk3399`, where 10.77.1.123 is the IP address of the tracker. | ||
|
||
* For Android device | ||
* Build and install tvm RPC apk on your device [Help](https://github.com/dmlc/tvm/tree/master/apps/android_rpc). | ||
Make sure you can pass the android rpc test. Then you have alreadly known how to register. | ||
|
||
3. Verify the device registration | ||
We can query all registered devices by | ||
```bash | ||
python3 -m tvm.exec.query_rpc_tracker | ||
``` | ||
You should be able to find your devices in `Queue Status`. Make sure the registration is correct before going ahead. | ||
|
||
For our test environment, one sample output can be | ||
```bash | ||
Queue Status | ||
------------------------------ | ||
key free pending | ||
------------------------------ | ||
mate10pro 1 0 | ||
p20pro 2 0 | ||
pixel2 2 0 | ||
rk3399 2 0 | ||
rasp3b 8 0 | ||
``` | ||
|
||
4. Run benchmark | ||
We did auto-tuning for Huawei P20/Mate10 Pro, Google Pixel2, Raspberry Pi3 and Firefly-RK3399, | ||
and release pre-tuned parameters in [this repo](https://github.com/uwsaml/tvm-distro). | ||
During compilation, TVM will download these operator parameters automatically. | ||
|
||
```bash | ||
python3 arm_cpu_imagenet_bench.py --device rasp3b --rpc-key rasp3b | ||
python3 arm_cpu_imagenet_bench.py --device rk3399 --rpc-key rk3399 | ||
python3 arm_cpu_imagenet_bench.py --device pixel2 --rpc-key pixel2 | ||
python3 arm_cpu_imagenet_bench.py --device p20pro --rpc-key p20pro | ||
python3 arm_cpu_imagenet_bench.py --device mate10pro --rpc-key mate10pro | ||
``` | ||
|
||
If your device has a same SoC of the above device, you can reuse these parameters | ||
(e.g. use `llvm -device=arm_cpu -mode=rk3399 -target=aarch64-linux-gnu` as target). | ||
Otherwise, you need to tune for your own device, please follow this | ||
[tutorial](https://docs.tvm.ai/tutorials/autotvm/tune_nnvm_arm.html). | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
"""Benchmark script for performance on ARM CPU. | ||
see README.md for the usage and results of this script. | ||
""" | ||
|
||
import argparse | ||
import time | ||
|
||
import numpy as np | ||
|
||
import nnvm.testing | ||
import nnvm.compiler | ||
import tvm | ||
from tvm import autotvm | ||
from tvm.contrib.util import tempdir | ||
import tvm.contrib.graph_runtime as runtime | ||
|
||
def get_network(name, batch_size): | ||
"""Get the symbol definition and random weight of a network""" | ||
input_shape = (batch_size, 3, 224, 224) | ||
output_shape = (batch_size, 1000) | ||
|
||
if name == 'resnet-18': | ||
net, params = nnvm.testing.resnet.get_workload(num_layers=18, | ||
batch_size=batch_size, image_shape=(3, 224, 224)) | ||
elif name == 'mobilenet': | ||
net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size) | ||
elif name == 'squeezenet v1.1': | ||
net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, | ||
version='1.1') | ||
elif name == 'vgg-16': | ||
net, params = nnvm.testing.vgg.get_workload(batch_size=batch_size, num_layers=16) | ||
else: | ||
raise RuntimeError("Unsupported network: " + name) | ||
|
||
return net, params, input_shape, output_shape | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--network", type=str, choices=['resnet-18', 'mobilenet', 'squeezenet v1.1', 'vgg-16']) | ||
parser.add_argument("--device", type=str, required=True, choices=['rk3399', 'mate10', 'mate10pro', 'p20', 'p20pro', | ||
'pixel2', 'rasp3b', 'pynq']) | ||
parser.add_argument("--host", type=str, default='localhost') | ||
parser.add_argument("--port", type=int, default=9190) | ||
parser.add_argument("--rpc-key", type=str, required=True) | ||
parser.add_argument("--number", type=int, default=6) | ||
args = parser.parse_args() | ||
|
||
dtype = 'float32' | ||
|
||
if args.network is None: | ||
networks = ['squeezenet v1.1', 'mobilenet', 'resnet-18', 'vgg-16'] | ||
else: | ||
networks = [args.network] | ||
|
||
target = tvm.target.arm_cpu(model=args.device) | ||
|
||
# connect to remote device | ||
tracker = tvm.rpc.connect_tracker(args.host, args.port) | ||
remote = tracker.request(args.rpc_key) | ||
|
||
print("--------------------------------------------------") | ||
print("%-20s %-20s" % ("Network Name", "Mean Inference Time (std dev)")) | ||
print("--------------------------------------------------") | ||
for network in networks: | ||
net, params, input_shape, output_shape = get_network(network, batch_size=1) | ||
|
||
with nnvm.compiler.build_config(opt_level=2, add_pass=['AlterOpLayout']): | ||
graph, lib, params = nnvm.compiler.build( | ||
net, target=target, shape={'data': input_shape}, params=params, dtype=dtype) | ||
|
||
tmp = tempdir() | ||
if 'android' in str(target): | ||
from tvm.contrib import ndk | ||
filename = "%s.so" % network | ||
lib.export_library(tmp.relpath(filename), ndk.create_shared) | ||
else: | ||
filename = "%s.tar" % network | ||
lib.export_library(tmp.relpath(filename)) | ||
|
||
# upload library and params | ||
ctx = remote.context(str(target), 0) | ||
remote.upload(tmp.relpath(filename)) | ||
rparams = {k: tvm.nd.array(v, ctx) for k, v in params.items()} | ||
|
||
rlib = remote.load_module(filename) | ||
module = runtime.create(graph, rlib, ctx) | ||
data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype)) | ||
module.set_input('data', data_tvm) | ||
module.set_input(**rparams) | ||
|
||
# evaluate | ||
ftimer = module.module.time_evaluator("run", ctx, number=args.number, repeat=3) | ||
prof_res = np.array(ftimer().results) * 1000 # multiply 1000 for converting to millisecond | ||
print("%-20s %-19s (%s)" % (network, "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res))) | ||
|
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.