Skip to content

Commit

Permalink
[Tutorial] mxnet (apache#47)
Browse files Browse the repository at this point in the history
* [Tutorial] mxnet

update

add from_gluon

add to __init__

fix tutorial and from_gluon

fix doc lint

merge from_mxnet

fix

fix

fix tutorial

fix

fix header

* fix tutorial

* fix data

* fix
  • Loading branch information
zhreshold authored and tqchen committed May 26, 2018
1 parent ba208f3 commit fd01043
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 18 deletions.
46 changes: 28 additions & 18 deletions nnvm/python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,6 @@ def _from_mxnet_impl(symbol, graph):
nnvm.sym.Symbol
Converted symbol
"""
try:
from mxnet import sym as mx_sym # pylint: disable=import-self
except ImportError as e:
raise ImportError('{}. MXNet is required to parse symbols.'.format(e))

if not isinstance(symbol, mx_sym.Symbol):
raise ValueError("Provided {}, while MXNet symbol is expected", type(symbol))

if _is_mxnet_group_symbol(symbol):
return [_from_mxnet_impl(s, graph) for s in symbol]

Expand Down Expand Up @@ -294,7 +286,7 @@ def from_mxnet(symbol, arg_params=None, aux_params=None):
Parameters
----------
symbol : mxnet.Symbol
symbol : mxnet.Symbol or mxnet.gluon.HybridBlock
MXNet symbol
arg_params : dict of str to mx.NDArray
Expand All @@ -305,18 +297,36 @@ def from_mxnet(symbol, arg_params=None, aux_params=None):
Returns
-------
net: nnvm.Symbol
sym : nnvm.Symbol
Compatible nnvm symbol
params : dict of str to tvm.NDArray
The parameter dict to be used by nnvm
"""
sym = _from_mxnet_impl(symbol, {})
params = {}
arg_params = arg_params if arg_params else {}
aux_params = aux_params if aux_params else {}
for k, v in arg_params.items():
params[k] = tvm.nd.array(v.asnumpy())
for k, v in aux_params.items():
params[k] = tvm.nd.array(v.asnumpy())
try:
import mxnet as mx # pylint: disable=import-self
except ImportError as e:
raise ImportError('{}. MXNet is required to parse symbols.'.format(e))

if isinstance(symbol, mx.sym.Symbol):
sym = _from_mxnet_impl(symbol, {})
params = {}
arg_params = arg_params if arg_params else {}
aux_params = aux_params if aux_params else {}
for k, v in arg_params.items():
params[k] = tvm.nd.array(v.asnumpy())
for k, v in aux_params.items():
params[k] = tvm.nd.array(v.asnumpy())
elif isinstance(symbol, mx.gluon.HybridBlock):
data = mx.sym.Variable('data')
sym = symbol(data)
sym = _from_mxnet_impl(sym, {})
params = {}
for k, v in symbol.collect_params().items():
params[k] = tvm.nd.array(v.data().asnumpy())
elif isinstance(symbol, mx.gluon.Block):
raise NotImplementedError("The dynamic Block is not supported yet.")
else:
msg = "mxnet.Symbol or gluon.HybridBlock expected, got {}".format(type(symbol))
raise ValueError(msg)
return sym, params
1 change: 1 addition & 0 deletions nnvm/python/nnvm/testing/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun. "Identity Mappings in Deep Residual Networks"
'''
# pylint: disable=unused-argument
import numpy as np
from .. import symbol as sym
from . utils import create_workload
Expand Down
114 changes: 114 additions & 0 deletions nnvm/tutorials/from_mxnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Compiling MXNet Models with NNVM
================================
**Author**: `Joshua Z. Zhang <https://zhreshold.github.io/>`_
This article is an introductory tutorial to deploy mxnet models with NNVM.
For us to begin with, mxnet module is required to be installed.
A quick solution is
```
pip install mxnet --user
```
or please refer to offical installation guide.
https://mxnet.incubator.apache.org/versions/master/install/index.html
"""
# some standard imports
import mxnet as mx
import nnvm
import tvm
import numpy as np

######################################################################
# Download Resnet18 model from Gluon Model Zoo
# ---------------------------------------------
# In this section, we download a pretrained imagenet model and classify an image.
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download
import Image
from matplotlib import pyplot as plt
block = get_model('resnet18_v1', pretrained=True)
img_name = 'cat.jpg'
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
'596b27d23537e5a1b5751d2b0481ef172f58b539/',
'imagenet1000_clsid_to_human.txt'])
synset_name = 'synset.txt'
download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name)
download(synset_url, synset_name)
with open(synset_name) as f:
synset = eval(f.read())
image = Image.open(img_name).resize((224, 224))
plt.imshow(image)
plt.show()

def transform_image(image):
image = np.array(image) - np.array([123., 117., 104.])
image /= np.array([58.395, 57.12, 57.375])
image = image.transpose((2, 0, 1))
image = image[np.newaxis, :]
return image

x = transform_image(image)
print('x', x.shape)

######################################################################
# Compile the Graph
# -----------------
# Now we would like to port the Gluon model to a portable computational graph.
# It's as easy as several lines.
# We support MXNet static graph(symbol) and HybridBlock in mxnet.gluon
sym, params = nnvm.frontend.from_mxnet(block)
# we want a probability so add a softmax operator
sym = nnvm.sym.softmax(sym)

######################################################################
# now compile the graph
import nnvm.compiler
target = 'cuda'
shape_dict = {'data': x.shape}
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params)

######################################################################
# Execute the portable graph on TVM
# ---------------------------------
# Now, we would like to reproduce the same forward computation using TVM.
from tvm.contrib import graph_runtime
ctx = tvm.gpu(0)
dtype = 'float32'
m = graph_runtime.create(graph, lib, ctx)
# set inputs
m.set_input('data', tvm.nd.array(x.astype(dtype)))
m.set_input(**params)
# execute
m.run()
# get outputs
tvm_output = m.get_output(0, tvm.nd.empty((1000,), dtype))
top1 = np.argmax(tvm_output)
print('TVM prediction top-1:', top1, synset[top1])

######################################################################
# Use MXNet symbol with pretrained weights
# ----------------------------------------
# MXNet often use `arg_prams` and `aux_params` to store network parameters
# separately, here we show how to use these weights with existing API
def block2symbol(block):
data = mx.sym.Variable('data')
sym = block(data)
args = {}
auxs = {}
for k, v in block.collect_params().items():
args[k] = mx.nd.array(v.data().asnumpy())
return sym, args, auxs
mx_sym, args, auxs = block2symbol(block)
# usually we would save/load it as checkpoint
mx.model.save_checkpoint('resnet18_v1', 0, mx_sym, args, auxs)
# there are 'resnet18_v1-0000.params' and 'resnet18_v1-symbol.json' on disk

######################################################################
# for a normal mxnet model, we start from here
mx_sym, args, auxs = mx.model.load_checkpoint('resnet18_v1', 0)
# now we use the same API to get NNVM compatible symbol
nnvm_sym, nnvm_params = nnvm.frontend.from_mxnet(mx_sym, args, auxs)
# repeat the same steps to run this model using TVM

0 comments on commit fd01043

Please sign in to comment.